vllm.compilation.codegen

Code generation for split_gm stitching graph execution.

Generates a plain Python function that replaces the FX GraphModule's interpreter-based execution of the stitching graph, eliminating nn.Module.__call__ overhead and getattr dispatch.

_node_ref

_node_ref(arg: Any) -> str

Convert an FX node argument to a source code reference recursively.

Source code in vllm/compilation/codegen.py
def _node_ref(arg: Any) -> str:
    """Convert an FX node argument to a source code reference recursively."""
    if isinstance(arg, torch.fx.Node):
        return arg.name
    if isinstance(arg, list):
        return f"[{', '.join(_node_ref(x) for x in arg)}]"
    if isinstance(arg, tuple):
        items = ", ".join(_node_ref(x) for x in arg)
        return f"({items},)" if len(arg) == 1 else f"({items})"
    if isinstance(arg, dict):
        return (
            "{"
            + ", ".join(f"{_node_ref(k)}: {_node_ref(v)}" for k, v in arg.items())
            + "}"
        )
    return repr(arg)
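
The recursion above can be tried without PyTorch. The following is a minimal sketch, assuming a hypothetical FakeNode class in place of torch.fx.Node (only its name attribute is used) and a local node_ref copy of the function. It shows how nested containers are rendered as source text, including the trailing comma for one-element tuples.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; only .name is used."""

    name: str


def node_ref(arg: Any) -> str:
    # Same recursion as _node_ref, with FakeNode replacing torch.fx.Node.
    if isinstance(arg, FakeNode):
        return arg.name
    if isinstance(arg, list):
        return f"[{', '.join(node_ref(x) for x in arg)}]"
    if isinstance(arg, tuple):
        items = ", ".join(node_ref(x) for x in arg)
        # A one-element tuple needs a trailing comma to remain a tuple.
        return f"({items},)" if len(arg) == 1 else f"({items})"
    if isinstance(arg, dict):
        body = ", ".join(f"{node_ref(k)}: {node_ref(v)}" for k, v in arg.items())
        return "{" + body + "}"
    # Anything else (ints, floats, strings, None) is emitted via repr().
    return repr(arg)


x, y = FakeNode("x"), FakeNode("submod_0")
single = node_ref((x,))
nested = node_ref([x, (y, 1), {"scale": 0.5}])
print(single)  # (x,)
print(nested)  # [x, (submod_0, 1), {'scale': 0.5}]
```

Nodes are emitted by name so that they resolve to local variables in the generated function, while constants fall through to repr().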

compile_execution_fn

compile_execution_fn(
    code: str,
    submod_callables: dict[str, Callable[..., Any]],
    submod_names: list[str],
) -> Callable[..., Any]

Compile execution code and bind submodule callables.

Parameters:

Name | Type | Description | Default
code | str | Python source from generate_execution_code(). | required
submod_callables | dict[str, Callable[..., Any]] | Mapping of submodule names to their callables. | required
submod_names | list[str] | Ordered list of submodule names matching the indices used in the generated code. | required

Returns:

Type | Description
Callable[..., Any] | A callable that executes the stitching logic.

Source code in vllm/compilation/codegen.py
@dynamo_timed("vllm.compile_execution_fn")
def compile_execution_fn(
    code: str,
    submod_callables: dict[str, Callable[..., Any]],
    submod_names: list[str],
) -> Callable[..., Any]:
    """Compile execution code and bind submodule callables.

    Args:
        code: Python source from generate_execution_code().
        submod_callables: Mapping of submodule names to their callables.
        submod_names: Ordered list of submodule names matching the indices
            used in the generated code.

    Returns:
        A callable that executes the stitching logic.
    """
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "vllm_execution_code",
            "encoding": "string",
        },
        payload_fn=lambda: code,
    )
    namespace: dict[str, Any] = {}
    exec(code, namespace)  # noqa: S102
    fn = namespace["execution_fn"]
    # Using .get() is intentional here because only piecewise backend will
    # be stored in submod_callables. The other submodules are inlined and
    # we don't need to bind them to the execution function. Instead, we
    # should use None as placeholder to ensure the list indices are preserved
    # for better debuggability.
    submods_list = [submod_callables.get(name) for name in submod_names]
    return partial(fn, __vllm_submods__=submods_list)
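
The exec-and-bind pattern used above can be illustrated with the standard library alone. This is a sketch under stated assumptions: the source string below is hand-written and hypothetical (the real output of generate_execution_code() will differ), and the lambdas stand in for piecewise backends. It shows why .get() is used: a submodule that is absent from submod_callables becomes None, so later indices stay aligned.

```python
from functools import partial
from typing import Any, Callable

# Hypothetical generated source, written by hand for illustration only.
code = """
def execution_fn(x, *, __vllm_submods__):
    a = __vllm_submods__[0](x)
    b = a + 1  # stands in for an inlined submodule (index 1 is unused)
    return __vllm_submods__[2](b)
"""

# Compile the source and pull the function out of the namespace.
namespace: dict[str, Any] = {}
exec(code, namespace)
fn = namespace["execution_fn"]

submod_names = ["submod_0", "submod_1", "submod_2"]
# Only two of the three names have callables; "submod_1" is treated as inlined.
submod_callables: dict[str, Callable[..., Any]] = {
    "submod_0": lambda v: v * 2,
    "submod_2": lambda v: v - 3,
}

# .get() yields None for the missing name, preserving list positions.
submods_list = [submod_callables.get(name) for name in submod_names]
run = partial(fn, __vllm_submods__=submods_list)

result = run(5)
print(submods_list[1])  # None
print(result)  # 8
```

Binding the list once with partial means callers pass only the graph inputs, and each submodule call inside the function is a list index rather than an attribute or dict lookup.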

generate_execution_code

generate_execution_code(
    split_gm: GraphModule,
) -> tuple[str, list[str]]

Generate Python source code from a split_gm's stitching graph.

Walks split_gm.graph.nodes and produces a function that calls submodules via a __vllm_submods__ list, avoiding FX GraphModule overhead and dict lookup cost.

If a submodule is a plain torch.fx.GraphModule, it is inlined directly in the generated code and we do not need to serialize it in the artifact.

Parameters:

Name | Type | Description | Default
split_gm | GraphModule | The split graph module produced by split_graph(). | required

Returns:

Type | Description
tuple[str, list[str]] | A tuple of (code, submod_names) where code is the Python source and submod_names is the ordered list of submodule target names corresponding to list indices used in the generated code.

Source code in vllm/compilation/codegen.py
@dynamo_timed("vllm.generate_execution_code")
def generate_execution_code(
    split_gm: torch.fx.GraphModule,
) -> tuple[str, list[str]]:
    """Generate Python source code from a split_gm's stitching graph.

    Walks split_gm.graph.nodes and produces a function that calls
    submodules via a __vllm_submods__ list, avoiding FX GraphModule overhead
    and dict lookup cost.

    If a submodule is a plain torch.fx.GraphModule, it is inlined directly
    in the generated code and we do not need to serialize it in the artifact.

    Args:
        split_gm: The split graph module produced by split_graph().

    Returns:
        A tuple of (code, submod_names) where code is the Python source
        and submod_names is the ordered list of submodule target names
        corresponding to list indices used in the generated code.
    """

    code, submod_names = generate_execution_code_with_name(
        split_gm, "execution_fn", with_submod=True
    )
    return "import torch\nimport operator\n" + code, submod_names
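
The body of generate_execution_code_with_name is not shown on this page, so the following is only a toy sketch of the general idea: walk an ordered list of calls and emit source that reaches submodules by list index. Step and toy_generate are hypothetical names invented for this illustration, and the toy omits both the inlining of plain GraphModules and the "import torch" / "import operator" prefix that the real function prepends.

```python
from typing import NamedTuple


class Step(NamedTuple):
    """Hypothetical stand-in for one submodule call in the stitching graph."""

    out: str  # name of the value this step produces
    target: str  # name of the submodule being called
    args: tuple[str, ...]  # names of the input values


def toy_generate(
    inputs: list[str], steps: list[Step], result: str
) -> tuple[str, list[str]]:
    submod_names: list[str] = []
    lines = [f"def execution_fn({', '.join(inputs)}, *, __vllm_submods__):"]
    for step in steps:
        # Each submodule gets a fixed list index, in first-seen order.
        if step.target not in submod_names:
            submod_names.append(step.target)
        idx = submod_names.index(step.target)
        call_args = ", ".join(step.args)
        lines.append(f"    {step.out} = __vllm_submods__[{idx}]({call_args})")
    lines.append(f"    return {result}")
    return "\n".join(lines) + "\n", submod_names


code, names = toy_generate(
    inputs=["x"],
    steps=[Step("a", "submod_0", ("x",)), Step("b", "submod_1", ("a",))],
    result="b",
)
print(code)
print(names)  # ['submod_0', 'submod_1']
```

The returned names list plays the role of submod_names: position i in that list is the index the generated code uses, which is the contract compile_execution_fn relies on when it builds the bound list.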