mat-mul commited on
Commit
df9d415
ยท
verified ยท
1 Parent(s): f573d22

fix parsing logic

Browse files
kanana_tool_calls/functionary_kanana_tool_parser.py CHANGED
@@ -37,7 +37,6 @@ class BaseTemplate(ABC):
37
  def response_to_messages(self, generated_text):
38
  raise NotImplementedError
39
 
40
-
41
  class FunctionaryV3Llama31Template(BaseTemplate):
42
  def __init__(
43
  self,
@@ -144,7 +143,6 @@ def is_complete_json(input_str):
144
  except JSONDecodeError:
145
  return False
146
 
147
-
148
  @ToolParserManager.register_module(["functionary_v3_llama_31"])
149
  class FunctionaryV3Llama31ToolParser(ToolParser):
150
  def __init__(self, tokenizer: Union[PreTrainedTokenizerBase, AnyTokenizer]):
@@ -162,6 +160,8 @@ class FunctionaryV3Llama31ToolParser(ToolParser):
162
  self._python_tag_id = tokenizer.encode(self._python_tag,
163
  add_special_tokens=False)[0]
164
 
 
 
165
  def extract_tool_calls(
166
  self, model_output: str,
167
  request: ChatCompletionRequest) -> ExtractedToolCallInformation:
@@ -175,7 +175,7 @@ class FunctionaryV3Llama31ToolParser(ToolParser):
175
  tool_calls=[],
176
  content=result["content"])
177
 
178
- # our template: <function=function_name>{"arg":"var"}<function>
179
  def extract_tool_calls_streaming(
180
  self,
181
  previous_text: str,
@@ -186,13 +186,36 @@ class FunctionaryV3Llama31ToolParser(ToolParser):
186
  delta_token_ids: Sequence[int],
187
  request: ChatCompletionRequest,
188
  ) -> Union[DeltaMessage, None]:
 
189
  # if current_text does not start with function tag (or python tag),
190
- # stream right away as delta.content
191
  if not (current_text.startswith(self._python_tag)
192
  or current_text.startswith(self._func_prefix)
193
  or self._func_prefix.startswith(current_text)):
194
- return DeltaMessage(content=delta_text)
195
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  # if current_text ends with stop token,
197
  # remove it from the text
198
  # CHECK: sometimes text is generated beyond <|eom_id|>
@@ -311,7 +334,7 @@ class FunctionaryV3Llama31ToolParser(ToolParser):
311
  else:
312
  delta = None
313
  # re-set stuff pertaining to progress in the current tool
314
- self.current_tool_id = len(tool_call_arr) - 1 # update current tool call # 1์„ ๋”ํ•˜๋Š”๊ฒŒ ์•„๋‹ˆ๋ผ ์ด๋ ‡๊ฒŒ ํ•˜๋„ค ..
315
  self.current_tool_name_sent = False
316
  self.streamed_args_for_tool.append("")
317
  logger.debug("starting on new tool %d", self.current_tool_id)
 
37
  def response_to_messages(self, generated_text):
38
  raise NotImplementedError
39
 
 
40
  class FunctionaryV3Llama31Template(BaseTemplate):
41
  def __init__(
42
  self,
 
143
  except JSONDecodeError:
144
  return False
145
 
 
146
  @ToolParserManager.register_module(["functionary_v3_llama_31"])
147
  class FunctionaryV3Llama31ToolParser(ToolParser):
148
  def __init__(self, tokenizer: Union[PreTrainedTokenizerBase, AnyTokenizer]):
 
160
  self._python_tag_id = tokenizer.encode(self._python_tag,
161
  add_special_tokens=False)[0]
162
 
163
+ # added buffer for each tool call parser
164
+ self._buffer = ""
165
  def extract_tool_calls(
166
  self, model_output: str,
167
  request: ChatCompletionRequest) -> ExtractedToolCallInformation:
 
175
  tool_calls=[],
176
  content=result["content"])
177
 
178
+ # our template: <function=function_name>{"arg":"var"}</function>
179
  def extract_tool_calls_streaming(
180
  self,
181
  previous_text: str,
 
186
  delta_token_ids: Sequence[int],
187
  request: ChatCompletionRequest,
188
  ) -> Union[DeltaMessage, None]:
189
+
190
  # if current_text does not start with function tag (or python tag),
 
191
  if not (current_text.startswith(self._python_tag)
192
  or current_text.startswith(self._func_prefix)
193
  or self._func_prefix.startswith(current_text)):
194
+ # for cases like "The answer is <function="
195
+ # let current_text="<function="
196
+ if self._func_prefix in current_text:
197
+ idx = current_text.find(self._func_prefix)
198
+ current_text = current_text[idx:]
199
+ self._buffer = ""
200
+ # for cases like "The answer is <function"
201
+ # add delta_text to buffer to figure out whether to print or not later
202
+ elif delta_text.endswith("<") or (current_text.endswith("<function") and delta_text.endswith("function")):
203
+ self._buffer += delta_text
204
+ return DeltaMessage(content=None)
205
+ # for cases that does not include "<function" at all,
206
+ # stream right away as delta.content
207
+ else:
208
+ delta_text = self._buffer + delta_text
209
+ self._buffer = ""
210
+ return DeltaMessage(content=delta_text)
211
+
212
+ # for cases like "<" or "<function"
213
+ # add to delta_text to buffer
214
+ if delta_text.endswith("<") or (current_text.endswith("<function") and delta_text.endswith("function")):
215
+ self._buffer += delta_text
216
+
217
+
218
+ # CHECK: this part not working (should use current_token_ids)
219
  # if current_text ends with stop token,
220
  # remove it from the text
221
  # CHECK: sometimes text is generated beyond <|eom_id|>
 
334
  else:
335
  delta = None
336
  # re-set stuff pertaining to progress in the current tool
337
+ self.current_tool_id = len(tool_call_arr) - 1 # update current tool call
338
  self.current_tool_name_sent = False
339
  self.streamed_args_for_tool.append("")
340
  logger.debug("starting on new tool %d", self.current_tool_id)