fix parsing logic
Browse files
kanana_tool_calls/functionary_kanana_tool_parser.py
CHANGED
|
@@ -37,7 +37,6 @@ class BaseTemplate(ABC):
|
|
| 37 |
def response_to_messages(self, generated_text):
|
| 38 |
raise NotImplementedError
|
| 39 |
|
| 40 |
-
|
| 41 |
class FunctionaryV3Llama31Template(BaseTemplate):
|
| 42 |
def __init__(
|
| 43 |
self,
|
|
@@ -144,7 +143,6 @@ def is_complete_json(input_str):
|
|
| 144 |
except JSONDecodeError:
|
| 145 |
return False
|
| 146 |
|
| 147 |
-
|
| 148 |
@ToolParserManager.register_module(["functionary_v3_llama_31"])
|
| 149 |
class FunctionaryV3Llama31ToolParser(ToolParser):
|
| 150 |
def __init__(self, tokenizer: Union[PreTrainedTokenizerBase, AnyTokenizer]):
|
|
@@ -162,6 +160,8 @@ class FunctionaryV3Llama31ToolParser(ToolParser):
|
|
| 162 |
self._python_tag_id = tokenizer.encode(self._python_tag,
|
| 163 |
add_special_tokens=False)[0]
|
| 164 |
|
|
|
|
|
|
|
| 165 |
def extract_tool_calls(
|
| 166 |
self, model_output: str,
|
| 167 |
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
|
@@ -175,7 +175,7 @@ class FunctionaryV3Llama31ToolParser(ToolParser):
|
|
| 175 |
tool_calls=[],
|
| 176 |
content=result["content"])
|
| 177 |
|
| 178 |
-
# our template: <function=function_name>{"arg":"var"}<function>
|
| 179 |
def extract_tool_calls_streaming(
|
| 180 |
self,
|
| 181 |
previous_text: str,
|
|
@@ -186,13 +186,36 @@ class FunctionaryV3Llama31ToolParser(ToolParser):
|
|
| 186 |
delta_token_ids: Sequence[int],
|
| 187 |
request: ChatCompletionRequest,
|
| 188 |
) -> Union[DeltaMessage, None]:
|
|
|
|
| 189 |
# if current_text does not start with function tag (or python tag),
|
| 190 |
-
# stream right away as delta.content
|
| 191 |
if not (current_text.startswith(self._python_tag)
|
| 192 |
or current_text.startswith(self._func_prefix)
|
| 193 |
or self._func_prefix.startswith(current_text)):
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
# if current_text ends with stop token,
|
| 197 |
# remove it from the text
|
| 198 |
# CHECK: sometimes text is generated beyond <|eom_id|>
|
|
@@ -311,7 +334,7 @@ class FunctionaryV3Llama31ToolParser(ToolParser):
|
|
| 311 |
else:
|
| 312 |
delta = None
|
| 313 |
# re-set stuff pertaining to progress in the current tool
|
| 314 |
-
self.current_tool_id = len(tool_call_arr) - 1 # update current tool call
|
| 315 |
self.current_tool_name_sent = False
|
| 316 |
self.streamed_args_for_tool.append("")
|
| 317 |
logger.debug("starting on new tool %d", self.current_tool_id)
|
|
|
|
| 37 |
def response_to_messages(self, generated_text):
|
| 38 |
raise NotImplementedError
|
| 39 |
|
|
|
|
| 40 |
class FunctionaryV3Llama31Template(BaseTemplate):
|
| 41 |
def __init__(
|
| 42 |
self,
|
|
|
|
| 143 |
except JSONDecodeError:
|
| 144 |
return False
|
| 145 |
|
|
|
|
| 146 |
@ToolParserManager.register_module(["functionary_v3_llama_31"])
|
| 147 |
class FunctionaryV3Llama31ToolParser(ToolParser):
|
| 148 |
def __init__(self, tokenizer: Union[PreTrainedTokenizerBase, AnyTokenizer]):
|
|
|
|
| 160 |
self._python_tag_id = tokenizer.encode(self._python_tag,
|
| 161 |
add_special_tokens=False)[0]
|
| 162 |
|
| 163 |
+
# added buffer for each tool call parser
|
| 164 |
+
self._buffer = ""
|
| 165 |
def extract_tool_calls(
|
| 166 |
self, model_output: str,
|
| 167 |
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
|
|
|
| 175 |
tool_calls=[],
|
| 176 |
content=result["content"])
|
| 177 |
|
| 178 |
+
# our template: <function=function_name>{"arg":"var"}</function>
|
| 179 |
def extract_tool_calls_streaming(
|
| 180 |
self,
|
| 181 |
previous_text: str,
|
|
|
|
| 186 |
delta_token_ids: Sequence[int],
|
| 187 |
request: ChatCompletionRequest,
|
| 188 |
) -> Union[DeltaMessage, None]:
|
| 189 |
+
|
| 190 |
# if current_text does not start with function tag (or python tag),
|
|
|
|
| 191 |
if not (current_text.startswith(self._python_tag)
|
| 192 |
or current_text.startswith(self._func_prefix)
|
| 193 |
or self._func_prefix.startswith(current_text)):
|
| 194 |
+
# for cases like "The answer is <function="
|
| 195 |
+
# let current_text="<function="
|
| 196 |
+
if self._func_prefix in current_text:
|
| 197 |
+
idx = current_text.find(self._func_prefix)
|
| 198 |
+
current_text = current_text[idx:]
|
| 199 |
+
self._buffer = ""
|
| 200 |
+
# for cases like "The answer is <function"
|
| 201 |
+
# add delta_text to buffer to figure out whether to print or not later
|
| 202 |
+
elif delta_text.endswith("<") or (current_text.endswith("<function") and delta_text.endswith("function")):
|
| 203 |
+
self._buffer += delta_text
|
| 204 |
+
return DeltaMessage(content=None)
|
| 205 |
+
# for cases that do not include "<function" at all,
|
| 206 |
+
# stream right away as delta.content
|
| 207 |
+
else:
|
| 208 |
+
delta_text = self._buffer + delta_text
|
| 209 |
+
self._buffer = ""
|
| 210 |
+
return DeltaMessage(content=delta_text)
|
| 211 |
+
|
| 212 |
+
# for cases like "<" or "<function"
|
| 213 |
+
# add delta_text to buffer
|
| 214 |
+
if delta_text.endswith("<") or (current_text.endswith("<function") and delta_text.endswith("function")):
|
| 215 |
+
self._buffer += delta_text
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# CHECK: this part is not working (should use current_token_ids)
|
| 219 |
# if current_text ends with stop token,
|
| 220 |
# remove it from the text
|
| 221 |
# CHECK: sometimes text is generated beyond <|eom_id|>
|
|
|
|
| 334 |
else:
|
| 335 |
delta = None
|
| 336 |
# re-set stuff pertaining to progress in the current tool
|
| 337 |
+
self.current_tool_id = len(tool_call_arr) - 1 # update current tool call
|
| 338 |
self.current_tool_name_sent = False
|
| 339 |
self.streamed_args_for_tool.append("")
|
| 340 |
logger.debug("starting on new tool %d", self.current_tool_id)
|