Skip to content

vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils

dequantize_to_dtype

dequantize_to_dtype(
    tensor_fp4: Tensor,
    tensor_sf: Tensor,
    global_scale: Tensor,
    dtype: dtype,
    block_size: int = 16,
    swizzle: bool | None = True,
)

Dequantize the fp4 tensor back to high precision.

Supports both 2D and 3D inputs: 2D `[m, packed_k] -> [m, k]`, and 3D `[dim0, m, packed_k] -> [dim0, m, k]`.

Source code in vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
def dequantize_to_dtype(
    tensor_fp4: torch.Tensor,
    tensor_sf: torch.Tensor,
    global_scale: torch.Tensor,
    dtype: torch.dtype,
    block_size: int = 16,
    swizzle: bool | None = True,
):
    """Expand packed fp4 data and apply its scales to get a `dtype` tensor.

    Accepted layouts:
    - 2D: [m, packed_k] -> [m, k]
    - 3D: [dim0, m, packed_k] -> [dim0, m, k]

    where ``k = 2 * packed_k`` because every uint8 carries two fp4 values.

    Args:
        tensor_fp4: Packed fp4 payload; must have dtype ``torch.uint8``.
        tensor_sf: Per-block scale factors. They are reinterpreted as
            ``torch.float8_e4m3fn`` and then promoted to float32.
        global_scale: Factor multiplied into the per-block scales. For 3D
            input it is indexed as ``global_scale[:, None, None]`` so that
            it broadcasts against ``[dim0, m, k // block_size]``
            (presumably one entry per ``dim0`` slice -- confirm with callers).
        dtype: Dtype of the returned tensor.
        block_size: Number of consecutive values sharing one scale factor.
        swizzle: When truthy, the scale factors are first passed through
            ``convert_swizzled_to_linear``.

    Returns:
        The scaled values cast to ``dtype``, with the block axes flattened
        back into the last dimension.
    """
    # Each byte holds a pair of fp4 values, so only uint8 storage is valid.
    assert tensor_fp4.dtype == torch.uint8

    # Batched (3D) input is flattened to 2D for unpacking and restored below.
    batched = tensor_fp4.ndim == 3
    if batched:
        num_batches, rows, packed_k = tensor_fp4.shape
        tensor_fp4 = tensor_fp4.reshape(-1, packed_k)
        tensor_sf = tensor_sf.reshape(-1, tensor_sf.shape[-1])
        global_scale = global_scale[:, None, None]
    else:
        rows, packed_k = tensor_fp4.shape

    k = 2 * packed_k
    blocks_per_row = k // block_size

    # Unpack to float32 and group the values by scaling block.
    values = break_fp4_bytes(tensor_fp4, torch.float32)
    values = values.reshape(-1, blocks_per_row, block_size)

    scales = tensor_sf.view(torch.float8_e4m3fn)
    if swizzle:
        scales = convert_swizzled_to_linear(
            scales, values.size(0), k, block_size
        )

    if batched:
        # Bring back the leading batch dimension on both operands.
        scales = scales.reshape(num_batches, rows, blocks_per_row)
        values = values.reshape(num_batches, rows, -1, block_size)

    # Fold the global scale into the per-block scales, then apply one scale
    # per block by broadcasting over the trailing block axis.
    block_scales = scales.to(torch.float32) * global_scale
    scaled = values * block_scales.unsqueeze(-1)

    # Collapse (num_blocks, block_size) back into a single k axis.
    flattened = scaled.reshape(*scaled.shape[:-2], -1)
    return flattened.to(dtype)

ref_nvfp4_quant_dequant

ref_nvfp4_quant_dequant(
    x: Tensor, global_scale: Tensor, block_size: int
) -> tuple[Tensor, None]

NVFP4 quantize-dequantize operation.

`global_scale` is expected to have a single element.

Source code in vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
def ref_nvfp4_quant_dequant(
    x: torch.Tensor, global_scale: torch.Tensor, block_size: int
) -> tuple[torch.Tensor, None]:
    """
    Round-trip `x` through NVFP4 quantization and back.

    `x` must be 2D. `global_scale` is expected to have a single element.

    Returns a pair whose first item is the dequantized tensor, with the
    same shape and dtype as `x`, and whose second item is always `None`.
    """
    rows, cols = x.shape
    orig_dtype = x.dtype

    # Quantize: yields the quantized values and the per-block scales.
    quantized, block_scales = ref_nvfp4_quant(x, global_scale, block_size)

    # Dequantize: group values by block, divide the global scale out of the
    # block scales, and broadcast one scale over each block.
    blocks = quantized.reshape(rows, cols // block_size, block_size)
    effective_scales = block_scales.unsqueeze(-1) / global_scale
    dequantized = blocks * effective_scales

    return dequantized.reshape(rows, cols).to(orig_dtype), None