class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization."""
def __init__(
self,
quant_config: GPTQMarlinConfig,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config
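        # GPTQ stores weights as unsigned integers with a symmetric bias
        # (8 for 4-bit, 128 for 8-bit), which maps onto the Marlin scalar
        # types below.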
if self.quant_config.quant_type.size_bits == 4:
quant_type = scalar_types.uint4b8
scale = kInt4StaticGroupScale
elif self.quant_config.quant_type.size_bits == 8:
quant_type = scalar_types.uint8b128
scale = kInt8StaticGroupScale
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
self.use_marlin = True
weight_key = QuantKey(quant_type, scale)
self.wna16_moe_backend, self.experts_cls = select_wna16_moe_backend(
moe, weight_key, quant_config.weight_bits
)

    def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.input_dtype = self.input_dtype
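        # Activations are treated as quantized when the configured input
        # dtype is a 1-byte (8-bit) type.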
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
            assert self.quant_config.quant_type.size_bits == 4, (
                "W8A8-INT8 is not supported by the Marlin kernel."
            )
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
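        # K is only "full" when act-order (desc_act) is disabled or the
        # intermediate dim is not sharded across TP ranks; the flag is
        # forwarded to the kernel in _setup_kernel().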
self.is_k_full = (not self.quant_config.desc_act) or (
intermediate_size_per_partition == intermediate_size_full
)
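        # group_size == -1 means per-channel scales (a single group);
        # otherwise one scale group per `group_size` input channels.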
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
w2_scales_size = (
intermediate_size_full
if self.quant_config.desc_act
else intermediate_size_per_partition
)
scales_size2 = w2_scales_size // self.quant_config.group_size
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
layer.num_groups_w13 = scales_size13
layer.num_groups_w2 = scales_size2
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
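        # Quantized weights are packed `pack_factor` values per int32 element
        # along the input dimension.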
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size // self.quant_config.pack_factor,
2 * intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition // self.quant_config.pack_factor,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
        # Fused gate_up_proj scales
w13_scales = torch.nn.Parameter(
torch.empty(
num_experts,
scales_size13,
2 * intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
# down_proj scales
w2_scales = torch.nn.Parameter(
torch.empty(num_experts, scales_size2, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# don't shard the w2 scales when running act order
set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
        # Fused gate_up_proj zero points
w13_qzeros = torch.nn.Parameter(
torch.empty(
num_experts,
scales_size13,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
        # down_proj zero points
w2_qzeros = torch.nn.Parameter(
torch.empty(
num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
        # don't shard the w2 zero points when running act order
set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
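        # Scratch workspace shared by the Marlin MoE kernels.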
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace_new(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
            assert self.quant_config.quant_type.size_bits == 4, (
                "W8A8-INT8 is not supported by the Marlin kernel."
            )
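        # Repack the GPTQ checkpoint tensors into the layout expected by the
        # selected WNA16 MoE backend.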
(
w13,
w2,
w13_scale,
w2_scale,
w13_g_idx,
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
w13_input_global_scale,
w2_input_global_scale,
w13_bias,
w2_bias,
) = convert_to_wna16_moe_kernel_format(
backend=self.wna16_moe_backend,
layer=layer,
quant_config=self.quant_config,
input_dtype=self.input_dtype,
w13=layer.w13_qweight,
w2=layer.w2_qweight,
w13_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w13_g_idx=layer.w13_g_idx,
w2_g_idx=layer.w2_g_idx,
w13_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
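        # Swap the layer parameters for the repacked tensors.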
replace_parameter(layer, "w13_qweight", w13)
replace_parameter(layer, "w2_qweight", w2)
replace_parameter(layer, "w13_scales", w13_scale)
replace_parameter(layer, "w2_scales", w2_scale)
replace_parameter(layer, "w13_g_idx", w13_g_idx)
replace_parameter(layer, "w2_g_idx", w2_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
if w13_input_global_scale is not None:
if hasattr(layer, "w13_input_global_scale"):
replace_parameter(
layer, "w13_input_global_scale", w13_input_global_scale
)
else:
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
if w2_input_global_scale is not None:
if hasattr(layer, "w2_input_global_scale"):
replace_parameter(layer, "w2_input_global_scale", w2_input_global_scale)
else:
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
if w13_bias is not None:
if hasattr(layer, "w13_bias"):
replace_parameter(layer, "w13_bias", w13_bias)
else:
layer.register_parameter(
"w13_bias", torch.nn.Parameter(w13_bias, requires_grad=False)
)
if w2_bias is not None:
if hasattr(layer, "w2_bias"):
replace_parameter(layer, "w2_bias", w2_bias)
else:
layer.register_parameter(
"w2_bias", torch.nn.Parameter(w2_bias, requires_grad=False)
)
self._setup_kernel(layer)

    def _setup_kernel(self, layer: FusedMoE) -> None:
"""Build the FusedMoEKernel for this layer."""
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
self.moe_kernel = make_wna16_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
layer=layer,
is_k_full=self.is_k_full,
w13_g_idx=layer.w13_g_idx,
w2_g_idx=layer.w2_g_idx,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)

    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
from vllm.model_executor.layers.fused_moe.config import (
gptq_marlin_moe_quant_config,
)
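        # Zero points are only relevant for asymmetric quantization;
        # symmetric checkpoints pass None.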
return gptq_marlin_moe_quant_config(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
weight_bits=self.quant_config.weight_bits,
group_size=self.quant_config.group_size,
w1_zp=getattr(layer, "w13_qzeros", None)
if not self.quant_config.is_sym
else None,
w2_zp=getattr(layer, "w2_qzeros", None)
if not self.quant_config.is_sym
else None,
w1_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)

    def select_gemm_impl(
self,
prepare_finalize,
layer: torch.nn.Module,
):
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel "
"initialization logic. This function should not be called."
)

    def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
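        # Dispatch through the modular MoE kernel built in _setup_kernel().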
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_qweight,
w2=layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
)