|
|
import json |
|
|
from collections.abc import Sequence |
|
|
from random import choices |
|
|
from string import ascii_letters, digits |
|
|
from typing import Union |
|
|
|
|
|
import partial_json_parser |
|
|
import regex as re |
|
|
from partial_json_parser.core.options import Allow |
|
|
from pydantic import Field |
|
|
|
|
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, |
|
|
DeltaFunctionCall, DeltaMessage, |
|
|
DeltaToolCall, |
|
|
ExtractedToolCallInformation, |
|
|
FunctionCall, ToolCall) |
|
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
|
|
ToolParser, ToolParserManager) |
|
|
from vllm.entrypoints.openai.tool_parsers.utils import ( |
|
|
extract_intermediate_diff) |
|
|
from vllm.logger import init_logger |
|
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer |
|
|
|
|
|
# Module-level logger created through vLLM's ``init_logger`` for this module.
logger = init_logger(__name__)

# Character pool (ASCII letters followed by digits) from which random tool
# call IDs are drawn.
ALPHANUMERIC = ascii_letters + digits
|
|
|
|
|
|
|
|
class NemotronToolCall(ToolCall):
    """Tool call whose ``id`` defaults to a freshly generated random ID."""

    # The lambda defers the lookup of NemotronToolCall until an instance is
    # actually created, i.e. after the class body has finished executing.
    id: str = Field(
        default_factory=lambda: NemotronToolCall.generate_random_id())

    @staticmethod
    def generate_random_id():
        """Return 9 characters drawn at random from ``ALPHANUMERIC``."""
        picked = choices(ALPHANUMERIC, k=9)
        return "".join(picked)

    @staticmethod
    def is_valid_id(id: str) -> bool:
        """Return True if *id* passes ``str.isalnum`` and is 9 chars long."""
        if not id.isalnum():
            return False
        return len(id) == 9
|
|
|
|
|
|
|
|
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
    """Return True for a ``MistralTokenizer`` whose ``version`` is >= 11.

    Any other tokenizer type yields False without its ``version`` being read.
    """
    if not isinstance(model_tokenizer, MistralTokenizer):
        return False
    return model_tokenizer.version >= 11
|
|
|
|
|
|
|
|
@ToolParserManager.register_module("nemotron_json")
class NemotronToolParser(ToolParser):
    """
    Tool call parser for Nemotron-Nano-V2.

    Tool calls are expected as a JSON array of ``{"name": ..., "arguments":
    ...}`` objects following a ``<TOOLCALL>`` tag (optionally terminated by
    ``</TOOLCALL>``). Both complete responses and streamed deltas are
    supported.

    Used when --enable-auto-tool-choice --tool-call-parser nemotron_json
    are all set.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up tag constants, regexes and the per-stream parsing state."""
        super().__init__(tokenizer)

        # Streaming state: the tool call array parsed on the previous delta,
        # the index of the tool currently being streamed (-1 = none yet),
        # whether its name was already sent, and the argument text already
        # sent for each tool.
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: list[str] = []

        self.bot_token = "<TOOLCALL>"
        self.tool_call_end_token = "</TOOLCALL>"
        # ``self.vocab`` is not defined in this class (presumably provided by
        # ToolParser); the id is only logged here, never required.
        self.bot_token_id = self.vocab.get(self.bot_token)
        logger.info("Nemotron Tool Parser: bot_token: %s, bot_token_id: %s",
                    self.bot_token, self.bot_token_id)

        # Fallback used when the raw payload is not directly JSON-decodable:
        # grabs the widest "[{ ... }]" span.
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
        if _is_fn_name_regex_support(self.model_tokenizer):
            self.fn_name_regex = re.compile(
                r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
        else:
            self.fn_name_regex = None

        # Text withheld from the client while it may still turn out to be
        # the beginning of a start/end tag split across several deltas.
        self._pending_tag_buffer: str = ""

    @staticmethod
    def _is_partial_tag(text: str, tag: str) -> bool:
        """Return True if *text* is a proper, case-insensitive prefix of
        *tag* (non-empty and strictly shorter than the tag)."""
        return (0 < len(text) < len(tag)
                and tag.upper().startswith(text.upper()))

    def _arguments_tool_call(self, arguments: str) -> DeltaToolCall:
        """Build the argument-only delta for the tool currently streamed."""
        return DeltaToolCall(
            index=self.current_tool_id,
            function=DeltaFunctionCall(arguments=arguments).model_dump(
                exclude_none=True))

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Disable special-token skipping when tools may be called.

        Applies only to non-Mistral tokenizers and only when the request
        carries tools and ``tool_choice`` is not ``'none'``. Presumably this
        keeps the tool-call tags in the decoded text -- confirm against the
        model's tokenizer config.
        """
        if (not isinstance(self.model_tokenizer, MistralTokenizer)
                and request.tools and request.tool_choice != 'none'):
            request.skip_special_tokens = False
        return request

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        The text remaining after removing every ``<TOOLCALL>`` tag is decoded
        as JSON; if that fails, the widest ``[{...}]`` span is decoded
        instead. Text preceding the first tag is returned as ``content``.

        Returns ``tools_called=False`` with the original output as content
        when no tag is present, and ``tools_called=False`` with the
        tag-stripped output as content when decoding fails.
        """
        if self.bot_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        tool_content = model_output.replace(self.bot_token, "").strip()

        try:
            try:
                if self.fn_name_regex:
                    function_call_arr = [{
                        "name": fn_name,
                        "arguments": json.loads(args)
                    } for fn_name, args in self.fn_name_regex.findall(
                        tool_content)]
                else:
                    function_call_arr = json.loads(tool_content)
            except json.JSONDecodeError:
                # Typical when a closing tag or leading prose surrounds the
                # array. An IndexError on "no match" is handled by the outer
                # except below.
                raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
                function_call_arr = json.loads(raw_tool_call)

            # Tolerate a single call emitted as a bare object.
            if isinstance(function_call_arr, dict):
                function_call_arr = [function_call_arr]

            tool_calls: list[NemotronToolCall] = [
                NemotronToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(raw_function_call["arguments"],
                                             ensure_ascii=False)))
                for raw_function_call in function_call_arr
            ]

            # Only text before the first tag counts as regular content.
            content = model_output.split(self.bot_token)[0]
            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=tool_calls,
                                                content=content or None)

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=tool_content)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Turn one streamed text delta into a ``DeltaMessage``.

        Args:
            previous_text: Text generated before this delta.
            current_text: Text generated so far, including this delta.
            delta_text: The newly generated text.
            previous_token_ids, current_token_ids, delta_token_ids, request:
                Accepted for interface compatibility; not read here.

        Returns:
            A content delta while no ``<TOOLCALL>`` tag has been seen, a tool
            call delta (name first, then argument fragments) once inside a
            tool call, or ``None`` when nothing can be emitted yet.
        """
        start_token = self.bot_token
        end_token = self.tool_call_end_token

        try:
            # A lone '<' may open a tag that arrives split across deltas;
            # withhold it until the following text settles the question.
            if delta_text == '<' and not self._pending_tag_buffer:
                self._pending_tag_buffer = '<'
                return None

            if self._pending_tag_buffer:
                self._pending_tag_buffer += delta_text
                pending = self._pending_tag_buffer

                if start_token in pending:
                    # Complete start tag: continue with the JSON parsing
                    # below. Parsing only looks at what follows the last
                    # start tag, so re-appending the buffer is harmless.
                    self._pending_tag_buffer = ""
                    current_text = previous_text + pending
                elif (self._is_partial_tag(pending, start_token)
                      or self._is_partial_tag(pending, end_token)):
                    # Still a possible tag prefix: keep waiting.
                    return None
                else:
                    # The buffer can no longer become a tag. Release it
                    # instead of withholding output indefinitely.
                    self._pending_tag_buffer = ""
                    if start_token not in current_text:
                        return DeltaMessage(content=pending)
                    # Inside a tool call the text is part of the JSON
                    # payload (or the completed end tag) and is picked up
                    # from current_text by the parsing below.

            # NOTE(review): text held back by these two checks is not
            # buffered; outside a tool call, a fragment that never completes
            # into a tag is therefore not emitted as content.
            if any(
                    current_text.endswith(start_token[:k])
                    for k in range(1, len(start_token))):
                return None
            if any(
                    current_text.endswith(end_token[:k])
                    for k in range(1, len(end_token))):
                return None
        except Exception:
            # Defensive fallback: still hold back obvious tag prefixes.
            if current_text.endswith(
                ('<', '<T', '<TO', '<TOOL', '<TOOLCALL')):
                return None

        # No tool call started: everything is regular content.
        if self.bot_token not in current_text:
            if self._pending_tag_buffer:
                content_to_flush = self._pending_tag_buffer + delta_text
                self._pending_tag_buffer = ""
                return DeltaMessage(content=content_to_flush)
            return DeltaMessage(content=delta_text)

        # Until the function name has been sent, do not accept partial
        # strings, so a name is only reported once it is complete.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        end_of_call: bool = False
        try:
            # Only the text after the last start tag is parsed.
            parsable_arr = current_text.split(self.bot_token)[-1]

            if end_token in parsable_arr:
                end_of_call = True
                parsable_arr = parsable_arr.split(end_token)[0]

            try:
                tool_call_arr: list[dict] = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                return None

            # Nothing usable yet, or the payload is not the expected array.
            if not isinstance(tool_call_arr, list) or not tool_call_arr:
                return None

            # With current_tool_id == -1 this selects the last element.
            current_tool_call: dict = tool_call_arr[self.current_tool_id]

            # case: a new tool appeared in the array
            if len(tool_call_arr) > self.current_tool_id + 1:
                delta = None
                # Flush whatever is left of the previous tool's arguments.
                if self.current_tool_id >= 0:
                    diff: Union[str, None] = current_tool_call.get(
                        "arguments")
                    if diff:
                        diff = json.dumps(diff, ensure_ascii=False).replace(
                            self.streamed_args_for_tool[self.current_tool_id],
                            "")
                        delta = DeltaMessage(
                            tool_calls=[self._arguments_tool_call(diff)])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += diff

                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return delta

            # case: the current tool's name has not been sent yet
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=NemotronToolCall.generate_random_id(),
                            function=DeltaFunctionCall(
                                name=function_name).model_dump(
                                    exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # case: stream the arguments of the current tool
            else:
                prev_arguments = self.prev_tool_call_arr[
                    self.current_tool_id].get("arguments")
                cur_arguments = current_tool_call.get("arguments")

                if not cur_arguments and not prev_arguments:
                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments,
                                                    ensure_ascii=False)
                    streamed_prefix = self.streamed_args_for_tool[
                        self.current_tool_id]

                    if (not streamed_prefix
                            and cur_arguments_json.endswith('": ""}')):
                        # A just-opened, still empty string value: send
                        # everything up to and including the opening quote,
                        # holding back the closing '"}' the serializer added.
                        arguments_delta = cur_arguments_json[:-2]
                    elif cur_arguments_json.startswith(streamed_prefix):
                        arguments_delta = cur_arguments_json[
                            len(streamed_prefix):]
                    else:
                        arguments_delta = extract_intermediate_diff(
                            cur_arguments_json, streamed_prefix)

                    # On the first fragment of an unfinished call, hold back
                    # a trailing '}' (and a closing quote before it): they
                    # come from serializing the partial object and may not
                    # be final.
                    if (not self.streamed_args_for_tool[self.current_tool_id]
                            and not end_of_call and arguments_delta
                            and arguments_delta.endswith('}')):
                        arguments_delta = arguments_delta[:-1]
                        if arguments_delta.endswith('"'):
                            arguments_delta = arguments_delta[:-1]

                    if arguments_delta:
                        delta = DeltaMessage(tool_calls=[
                            self._arguments_tool_call(arguments_delta)
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += arguments_delta
                    else:
                        delta = None

                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments,
                                               ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments,
                                                ensure_ascii=False)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)
                    if argument_diff:
                        delta = DeltaMessage(tool_calls=[
                            self._arguments_tool_call(argument_diff)
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff
                    else:
                        delta = None
                else:
                    delta = None

            # Remember this parse so the next delta can be diffed against it.
            self.prev_tool_call_arr = tool_call_arr

            # The closing tag was seen: send whatever part of the final
            # arguments has not been streamed yet (best effort).
            if end_of_call and self.current_tool_id >= 0:
                try:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments is not None:
                        cur_args_json = json.dumps(cur_arguments,
                                                   ensure_ascii=False)
                        streamed_prefix = self.streamed_args_for_tool[
                            self.current_tool_id]

                        if cur_args_json.startswith(streamed_prefix):
                            remaining_suffix = cur_args_json[
                                len(streamed_prefix):]
                        else:
                            remaining_suffix = extract_intermediate_diff(
                                cur_args_json, streamed_prefix)

                        if remaining_suffix and remaining_suffix.strip():
                            extra = self._arguments_tool_call(
                                remaining_suffix)
                            if delta is None:
                                delta = DeltaMessage(tool_calls=[extra])
                            elif getattr(delta, "tool_calls", None):
                                delta.tool_calls.append(extra)
                            else:
                                delta.tool_calls = [extra]
                            self.streamed_args_for_tool[
                                self.current_tool_id] += remaining_suffix
                except Exception:
                    # Keep the delta computed so far, but leave a trace
                    # instead of failing silently.
                    logger.exception(
                        "Error flushing remaining tool call arguments.")

            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None
|
|
|