Skip to content

vllm.tool_parsers.kimi_k2_tool_parser

KimiK2ToolParser

Bases: ToolParser

Source code in vllm/tool_parsers/kimi_k2_tool_parser.py
class KimiK2ToolParser(ToolParser):
    """Tool parser for output that marks tool calls with special-token text.

    The tool-call region opens with ``<|tool_calls_section_begin|>``; each call
    inside it is wrapped in ``<|tool_call_begin|>`` / ``<|tool_call_end|>`` and
    separates its header from its arguments with
    ``<|tool_call_argument_begin|>``.  Both one-shot extraction
    (``extract_tool_calls``) and incremental extraction
    (``extract_tool_calls_streaming``) are provided.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Set up markers, the extraction regex and the streaming state.

        Raises:
            ValueError: if ``self.model_tokenizer`` is falsy after the base
                class constructor has run.
        """
        super().__init__(tokenizer, tools)

        # --- streaming state -------------------------------------------
        # Number of leading characters of the text already emitted as content.
        self._sent_content_idx: int = 0
        # One dict per tool call seen so far ("name", "id", "arguments").
        self.prev_tool_call_arr: list[dict] = []
        # Argument text already streamed, one entry per tool call.
        self.streamed_args_for_tool: list[str] = []

        # --- marker that opens the tool-calls section --------------------
        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"

        # --- markers that delimit a single tool call ---------------------
        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"
        self.tool_call_arg_token: str = "<|tool_call_argument_begin|>"

        # One-shot extraction pattern: captures the call id ("<text>:<digits>")
        # and the argument text up to the end marker, without running across
        # another <|tool_call_begin|>.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
            r"<\|tool_call_argument_begin\|>\s*"
            r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
            r"<\|tool_call_end\|>",
            re.DOTALL,
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Let the base class adjust the request, then keep special tokens.

        When the request carries tools and ``tool_choice`` is not ``"none"``,
        ``skip_special_tokens`` is switched off so the markers show up as
        literal text and parsing can stay purely text-based.
        """
        adjusted = super().adjust_request(request)
        wants_tools = adjusted.tools and adjusted.tool_choice != "none"
        if wants_tools:
            adjusted.skip_special_tokens = False
        return adjusted

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every complete tool call from a finished model output.

        Returns the output untouched (``tools_called=False``) when the section
        marker is absent or when extraction raises; otherwise returns the
        parsed calls together with the text preceding the section marker
        (``None`` if that text is empty).
        """
        # Cheap early exit: nothing to parse without the section marker.
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            # The pattern has two groups, so findall yields
            # (tool_call_id, function_arguments) pairs.
            matches = self.tool_call_regex.findall(model_output)

            logger.debug("function_call_tuples: %s", matches)

            # An id looks like "functions.get_weather:0" or "get_weather:0";
            # the name is the part before ":" with any dotted prefix removed.
            tool_calls = [
                ToolCall(
                    id=call_id,
                    type="function",
                    function=FunctionCall(
                        name=call_id.split(":")[0].split(".")[-1],
                        arguments=call_args,
                    ),
                )
                for call_id, call_args in matches
            ]

            # The marker is known to be present, so the head of the partition
            # is exactly the text in front of its first occurrence.
            leading_text = model_output.partition(self.tool_calls_start_token)[0]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=leading_text or None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def _extract_content(self, current_text: str) -> str | None:
        """Return content ahead of the tool-calls section not yet sent.

        While the section marker has not appeared, a trailing suffix that
        could be the beginning of ``<|tool_calls_section_begin|>`` is withheld
        so marker fragments are never emitted.  Advances
        ``self._sent_content_idx`` when something is returned; returns ``None``
        when there is nothing new.
        """
        marker = self.tool_calls_start_token
        marker_pos = current_text.find(marker)
        if marker_pos != -1:
            # Everything before the marker is plain content.
            boundary = marker_pos
        else:
            held_back = partial_tag_overlap(current_text, marker)
            boundary = len(current_text) - held_back

        if boundary <= self._sent_content_idx:
            return None

        fresh = current_text[self._sent_content_idx : boundary]
        self._sent_content_idx = boundary
        return fresh

    def _extract_tool_calls(self, current_text: str) -> list[str]:
        """Collect the raw text between each call's begin and end markers.

        Scanning starts at the section marker.  A final call whose end marker
        has not arrived yet is included as well, minus any trailing characters
        that partially match ``<|tool_call_end|>``.
        """
        section_at = current_text.find(self.tool_calls_start_token)
        if section_at == -1:
            return []

        open_tag = self.tool_call_start_token
        close_tag = self.tool_call_end_token

        bodies: list[str] = []
        cursor = section_at
        while (open_at := current_text.find(open_tag, cursor)) != -1:
            body_from = open_at + len(open_tag)
            close_at = current_text.find(close_tag, body_from)

            if close_at == -1:
                # Unterminated call: take what is there, hold back a partial
                # end marker, and stop since nothing can follow it.
                tail = current_text[body_from:]
                held_back = partial_tag_overlap(tail, close_tag)
                bodies.append(tail[:-held_back] if held_back else tail)
                break

            bodies.append(current_text[body_from:close_at])
            cursor = close_at + len(close_tag)
        return bodies

    @staticmethod
    def _extract_tool_id_and_name(
        header: str | None,
    ) -> tuple[str | None, str | None]:
        """Derive ``(tool_id, tool_name)`` from a call header.

        For ``"functions.get_weather:0"`` the id is the whole header and the
        name is ``"get_weather"``.  Returns ``(None, None)`` when the header is
        ``None`` or does not start with text followed by ``:<digits>``.
        """
        if header is None:
            return None, None

        matched = re.match(r"(.+:\d+)", header)
        if matched is None:
            return None, None

        tool_id = matched.group(1).strip()
        # Text before the first ":" ...
        qualified_name = tool_id.partition(":")[0]
        # ... and of that, the piece after the last ".".
        tool_name = qualified_name.rpartition(".")[2]
        return tool_id, tool_name

    def _split_tool_call(self, tool_call: str) -> tuple[str | None, str | None]:
        """Cut a tool-call body into ``(header, arguments)`` at the argument marker.

        The header is whitespace-stripped; the arguments are returned as-is.
        Returns ``(None, None)`` while the marker is missing.

        Example::
            'get_weather:0 <|tool_call_argument_begin|>{"c'
            -> ("get_weather:0", '{"c')
        """
        header, marker, tool_args = tool_call.partition(self.tool_call_arg_token)
        if not marker:
            # partition leaves the separator slot empty when it was not found.
            return None, None
        return header.strip(), tool_args

    def _compute_args_diff(self, index: int, tool_args: str | None) -> str | None:
        """Return the part of ``tool_args`` not yet streamed for tool ``index``.

        Returns ``None`` when ``tool_args`` is ``None`` or is not longer than
        what was already sent.  Otherwise records ``tool_args`` in
        ``streamed_args_for_tool`` and ``prev_tool_call_arr`` and returns the
        newly added suffix.
        """
        if tool_args is None:
            return None

        already_sent = len(self.streamed_args_for_tool[index])
        if len(tool_args) <= already_sent:
            return None

        self.streamed_args_for_tool[index] = tool_args
        self.prev_tool_call_arr[index]["arguments"] = tool_args
        return tool_args[already_sent:]

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Produce the next streaming delta from the accumulated text.

        Only ``current_text`` is consulted; the remaining arguments are not
        used by this implementation.  Returns a ``DeltaMessage`` carrying new
        content and/or tool-call deltas, or ``None`` when there is nothing new
        or when an exception was raised (which is logged).
        """
        try:
            # Plain content in front of the tool-calls section.
            content = self._extract_content(current_text)
            deltas: list[DeltaToolCall] = []

            for idx, raw_call in enumerate(self._extract_tool_calls(current_text)):
                if idx >= len(self.prev_tool_call_arr):
                    # New tool call: create its streaming state.
                    self.prev_tool_call_arr.append({})
                    self.streamed_args_for_tool.append("")
                call_state = self.prev_tool_call_arr[idx]

                header, tool_args = self._split_tool_call(raw_call)

                if "name" not in call_state:
                    # The name has not been announced yet for this call.
                    tool_id, tool_name = self._extract_tool_id_and_name(header)
                    if not tool_name:
                        # Call idx is not ready, so later calls must wait too.
                        break
                    call_state["name"] = tool_name
                    call_state["id"] = tool_id
                    deltas.append(
                        DeltaToolCall(
                            index=idx,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(name=tool_name).model_dump(
                                exclude_none=True
                            ),
                        )
                    )

                # Emit only the argument text that was not sent before.
                args_diff = self._compute_args_diff(idx, tool_args)
                if args_diff:
                    deltas.append(
                        DeltaToolCall(
                            index=idx,
                            function=DeltaFunctionCall(arguments=args_diff).model_dump(
                                exclude_none=True
                            ),
                        )
                    )

            if not content and not deltas:
                return None
            return DeltaMessage(
                content=content,
                tool_calls=deltas,
            )

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None

_compute_args_diff

_compute_args_diff(
    index: int, tool_args: str | None
) -> str | None

Return the new argument text not yet sent for the tool at `index`, or None.

Source code in vllm/tool_parsers/kimi_k2_tool_parser.py
def _compute_args_diff(self, index: int, tool_args: str | None) -> str | None:
    """Return new argument text not yet sent for tool `index`, or None."""
    if tool_args is None:
        return None
    prev = self.streamed_args_for_tool[index]
    if len(tool_args) <= len(prev):
        return None
    diff = tool_args[len(prev) :]
    self.streamed_args_for_tool[index] = tool_args
    self.prev_tool_call_arr[index]["arguments"] = tool_args
    return diff

_extract_content

_extract_content(current_text: str) -> str | None

Return unsent content before the tool-calls section, or None.

Holds back any trailing suffix that partially matches <|tool_calls_section_begin|> to avoid leaking marker bytes.

Source code in vllm/tool_parsers/kimi_k2_tool_parser.py
def _extract_content(self, current_text: str) -> str | None:
    """Return unsent content before the tool-calls section, or None.

    Holds back any trailing suffix that partially matches
    ``<|tool_calls_section_begin|>`` to avoid leaking marker bytes.
    """
    if self.tool_calls_start_token not in current_text:
        overlap = partial_tag_overlap(current_text, self.tool_calls_start_token)
        sendable_idx = len(current_text) - overlap
    else:
        sendable_idx = current_text.index(self.tool_calls_start_token)

    if sendable_idx > self._sent_content_idx:
        content = current_text[self._sent_content_idx : sendable_idx]
        self._sent_content_idx = sendable_idx
        return content
    return None

_extract_tool_calls

_extract_tool_calls(current_text: str) -> list[str]

Extract raw bodies from <|tool_call_begin|>…<|tool_call_end|> blocks.

Source code in vllm/tool_parsers/kimi_k2_tool_parser.py
def _extract_tool_calls(self, current_text: str) -> list[str]:
    """Collect the raw text between each call's begin and end markers.

    Scanning starts at the section marker.  A final call whose end marker has
    not arrived yet is included as well, minus any trailing characters that
    partially match ``<|tool_call_end|>``.
    """
    section_at = current_text.find(self.tool_calls_start_token)
    if section_at == -1:
        return []

    open_tag = self.tool_call_start_token
    close_tag = self.tool_call_end_token

    bodies: list[str] = []
    cursor = section_at
    while (open_at := current_text.find(open_tag, cursor)) != -1:
        body_from = open_at + len(open_tag)
        close_at = current_text.find(close_tag, body_from)

        if close_at == -1:
            # Unterminated call: take what is there, hold back a partial end
            # marker, and stop since nothing can follow it.
            tail = current_text[body_from:]
            held_back = partial_tag_overlap(tail, close_tag)
            bodies.append(tail[:-held_back] if held_back else tail)
            break

        bodies.append(current_text[body_from:close_at])
        cursor = close_at + len(close_tag)
    return bodies

_extract_tool_id_and_name staticmethod

_extract_tool_id_and_name(
    header: str | None,
) -> tuple[str | None, str | None]

Parse (tool_id, tool_name) from a header like "functions.get_weather:0".

Source code in vllm/tool_parsers/kimi_k2_tool_parser.py
@staticmethod
def _extract_tool_id_and_name(
    header: str | None,
) -> tuple[str | None, str | None]:
    """Parse ``(tool_id, tool_name)`` from a header
    like ``"functions.get_weather:0"``."""
    if header is None:
        return None, None
    match = re.match(r"(.+:\d+)", header)
    if not match:
        return None, None

    tool_id = match.group(1).strip()
    tool_name = tool_id.split(":")[0].split(".")[-1]
    return tool_id, tool_name

_split_tool_call

_split_tool_call(
    tool_call: str,
) -> tuple[str | None, str | None]

Split a tool-call body into (header, arguments) at the argument marker.

Example: 'get_weather:0 <|tool_call_argument_begin|>{"c' -> ("get_weather:0", '{"c')

Source code in vllm/tool_parsers/kimi_k2_tool_parser.py
def _split_tool_call(self, tool_call: str) -> tuple[str | None, str | None]:
    """Split a tool-call body into ``(header, arguments)`` at the argument marker.

    Example::
        'get_weather:0 <|tool_call_argument_begin|>{"c'
        -> ("get_weather:0", '{"c')
    """
    arg_pos = tool_call.find(self.tool_call_arg_token)
    if arg_pos == -1:
        return None, None
    header = tool_call[:arg_pos].strip()
    tool_args = tool_call[arg_pos + len(self.tool_call_arg_token) :]
    return header, tool_args