Skip to content

vllm.model_executor.layers.fused_moe.oracle.mxfp4

_backend_activation_key

_backend_activation_key(
    backend: Mxfp4MoeBackend,
) -> QuantKey | None

Map backend to its activation key (MXFP8 or None for BF16).

Source code in vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
    """Return the activation quantization key that goes with ``backend``.

    The two FlashInfer MXFP4+MXFP8 backends yield ``kMxfp8Dynamic``; every
    other backend yields ``None``.
    """
    mxfp8_activation_backends = (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )
    return kMxfp8Dynamic if backend in mxfp8_activation_backends else None

_get_priority_backends

_get_priority_backends() -> list[Mxfp4MoeBackend]

Get available backends in priority order based on platform and config. Only includes BF16 backends. MXFP8 backends are selected via env vars.

Source code in vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
    """
    Return the candidate backends, highest priority first.

    None of the MXFP8-activation backends appear in this list. A new list
    object is built on every call, so callers may mutate the result freely.
    """
    return [
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.AITER,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.TRITON_UNFUSED,
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
        Mxfp4MoeBackend.XPU,
        Mxfp4MoeBackend.EMULATION,
    ]

convert_gpt_oss_weight_to_mxfp4_moe_kernel_format

convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
    mxfp4_backend: Mxfp4MoeBackend,
    layer: Module,
    w13_weight: Tensor,
    w2_weight: Tensor,
    w13_weight_scale: Tensor,
    w2_weight_scale: Tensor,
    w13_bias: Tensor | None = None,
    w2_bias: Tensor | None = None,
    _cache_permute_indices: dict[Size, Tensor]
    | None = None,
) -> tuple[
    Tensor,
    Tensor,
    Union[Tensor, PrecisionConfig],
    Union[Tensor, PrecisionConfig],
    Tensor | None,
    Tensor | None,
]

Convert loaded weights into backend-specific kernel format.

Source code in vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
    mxfp4_backend: Mxfp4MoeBackend,
    layer: torch.nn.Module,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    _cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Union[torch.Tensor, "PrecisionConfig"],
    Union[torch.Tensor, "PrecisionConfig"],
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Convert loaded weights into backend-specific kernel format.

    Dispatches on ``mxfp4_backend`` and rearranges the expert weights,
    their block scales and their biases into the layout used by that
    backend's kernel.

    Args:
        mxfp4_backend: Backend whose layout should be produced.
        layer: Owning module. It is forwarded to the Marlin helper, and the
            Triton branch deletes its ``w13_weight`` / ``w2_weight``
            attributes.
        w13_weight: Fused w1/w3 expert weights, indexed as
            ``(num_experts, rows, cols)``.
        w2_weight: w2 expert weights.
        w13_weight_scale: Block scales for ``w13_weight``.
        w2_weight_scale: Block scales for ``w2_weight``.
        w13_bias: Optional w13 bias. Required (asserted) by the TRTLLM,
            FlashInfer CUTLASS and Triton branches.
        w2_bias: Optional w2 bias; same requirement as ``w13_bias``.
        _cache_permute_indices: Permutation cache handed to FlashInfer's
            ``get_w2_permute_indices_with_cache``. Required (asserted) by
            the TRTLLM branch only.

    Returns:
        ``(w13_weight, w2_weight, w13_scale, w2_scale, w13_bias, w2_bias)``.
        In the Triton branch the two scale entries are ``PrecisionConfig``
        objects; in every other branch written here they are tensors
        (the Marlin branch returns whatever its helper returns).

    Raises:
        ValueError: If ``mxfp4_backend`` matches none of the branches.
        AssertionError: If an input required by the chosen branch is None.

    Note:
        The AITER branch mutates ``w13_weight``, ``w2_weight`` and
        ``w13_weight_scale`` in place (``copy_`` and ``.data`` rebinding).
    """

    # Sizes are derived from the w13 weight shape.
    # NOTE(review): the "// 2" presumably reflects w1 and w3 being fused along
    # dim 1, and the "* 2" two FP4 values packed per stored element — confirm.
    num_experts = w13_weight.shape[0]
    intermediate_size = w13_weight.shape[1] // 2
    hidden_size = w13_weight.shape[2] * 2

    sf_block_size = 32  # mxfp4 block size

    if mxfp4_backend in (
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
    ):
        # Marlin: the whole conversion is delegated to the Marlin helper.
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
            prepare_moe_mxfp4_layer_for_marlin,
        )

        return prepare_moe_mxfp4_layer_for_marlin(
            layer,
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )

    elif mxfp4_backend in TRTLLM_BACKENDS:
        # The per-expert shuffles below look up their permutation indices
        # through this cache, so it must be supplied.
        assert _cache_permute_indices is not None
        from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache

        # gemm1_alpha/beta/clamp_limit are created by the expert class
        # (TrtLlmMxfp4ExpertsBase), not on the layer.

        # Work on the raw tensors (``.data``); biases are cast to float32.
        w13_weight = w13_weight.data
        w2_weight = w2_weight.data
        w13_weight_scale = w13_weight_scale.data
        w2_weight_scale = w2_weight_scale.data
        assert w13_bias is not None and w2_bias is not None
        w13_bias = w13_bias.data.to(torch.float32)
        w2_bias = w2_bias.data.to(torch.float32)

        # Swap w1 and w3 as the definition of swiglu is different in trtllm-gen
        def swap_every_two_rows(x, axis=-1):
            """Swap each adjacent pair of entries of ``x`` along ``axis``.

            The axis is split into ``(size // 2, 2)``, the size-2 dimension
            is flipped, and the tensor is reshaped back to its input shape.
            """
            shape = x.shape
            if axis < 0:
                axis = len(shape) + axis
            new_shape = list(shape)
            new_shape[axis] = shape[axis] // 2
            new_shape.insert(axis + 1, 2)
            x = x.reshape(*new_shape)
            x = x.flip(axis + 1)
            # Restore the original shape after the pairwise flip.
            new_shape = list(shape)
            return x.reshape(*new_shape)

        # Weight and scale rows live on axis -2; the bias is swapped on axis -1.
        w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
        w13_weight = swap_every_two_rows(w13_weight, -2)
        w13_bias = swap_every_two_rows(w13_bias, -1)

        # Shuffle weights and scaling factors for transposed mma output
        gemm1_weights_shuffled = []
        gemm1_scales_shuffled = []
        gemm2_weights_shuffled = []
        gemm2_scales_shuffled = []
        gemm1_bias_shuffled = []
        gemm2_bias_shuffled = []
        epilogue_tile_m = 128
        # Each expert is processed separately: fetch the permutation indices
        # for the tensor, move them to the tensor's device, and index with them.
        for i in range(num_experts):
            # w13 weight
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_shuffled.append(
                w13_weight[i]
                .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                .contiguous()
            )
            # w13 scale
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            # Scales are additionally passed through the FlashInfer
            # block-scale interleave after being permuted.
            gemm1_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w13_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
                    .contiguous()
                )
            )
            # w13 bias
            # The bias is shuffled as a column vector (``reshape(-1, 1)``).
            permute_bias_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm1_bias_shuffled.append(
                w13_bias[i]
                .clone()
                .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                .contiguous()
            )
            # w2 weight
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_shuffled.append(
                w2_weight[i]
                .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                .contiguous()
            )
            # w2 scale
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w2_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
                    .contiguous()
                )
            )
            # w2 bias
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm2_bias_shuffled.append(
                w2_bias[i]
                .clone()
                .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                .contiguous()
            )

        # Re-assemble the per-expert results. Scales are reshaped to one entry
        # per ``sf_block_size`` elements and reinterpreted as float8_e4m3fn.
        w13_weight = torch.stack(gemm1_weights_shuffled)
        w13_weight_scale = (
            torch.stack(gemm1_scales_shuffled)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
            .view(torch.float8_e4m3fn)
        )
        w2_weight = torch.stack(gemm2_weights_shuffled)
        w2_weight_scale = (
            torch.stack(gemm2_scales_shuffled)
            .reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
            .view(torch.float8_e4m3fn)
        )
        # Biases go back from column vectors to one flat row per expert.
        w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
        w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)

        return (
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )

    elif mxfp4_backend in (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    ):
        # De-interleave and swap for w13 weight, bias, and scales
        # Even-indexed rows and odd-indexed rows are separated, then the two
        # halves are swapped, so the odd-indexed rows end up first (this
        # assumes an even row count so the two chunks are equal in size).
        w13_w = w13_weight.data
        gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
        deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
        w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
        w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

        # Same rearrangement for the bias, done in float32 and stored as
        # bfloat16.
        assert w13_bias is not None and w2_bias is not None
        w13_b = w13_bias.data.to(torch.float32)
        gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
        deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
        b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
        w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

        # Same rearrangement for the w13 scales.
        w13_s = w13_weight_scale.data
        gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
        deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
        s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
        w13_scale_swapped = torch.cat([s3, s1], dim=1)

        if mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8:
            from flashinfer import block_scale_interleave

            # MXFP8 variant: both scale tensors go through FlashInfer's
            # block-scale interleave and are reshaped back to their
            # pre-interleave shape.
            orig_shape = w13_scale_swapped.shape
            w13_scale_interleaved = block_scale_interleave(
                w13_scale_swapped.view(torch.uint8)
            ).reshape(orig_shape)

            w2_s = w2_weight_scale.data
            orig_shape = w2_s.shape
            w2_scale_interleaved = block_scale_interleave(
                w2_s.view(torch.uint8)
            ).reshape(orig_shape)

            # w2_weight and w2_bias are returned exactly as passed in.
            return (
                w13_weight_swapped,
                w2_weight,
                w13_scale_interleaved,
                w2_scale_interleaved,
                w13_bias_swapped,
                w2_bias,
            )

        else:
            assert mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16

            def _interleave_mxfp4_cutlass_sm90(w):
                """Regroup a 3-D scale tensor in blocks of four columns.

                ``(e, n, k)`` is viewed as ``(e, n, k // 4, 4)``, the two
                middle dimensions are swapped, and the result is flattened
                to ``(e, k // 4, n * 4)``.
                """
                w_shape = w.shape
                w_interleaved = w.reshape(w_shape[0], w_shape[1], (w_shape[2] // 4), 4)
                w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                w_interleaved = w_interleaved.reshape(
                    w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                )
                return w_interleaved

            # NOTE(review): ``.to(torch.uint8)`` is a value cast, unlike the
            # ``.view(torch.uint8)`` reinterpretation used in the other
            # branches; presumably the scales are already uint8 here — confirm.
            w31_scales = w13_scale_swapped.to(torch.uint8)
            w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)

            w2_scale = w2_weight_scale.data.to(torch.uint8)
            w2_scale_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scale)

            # w2_weight and w2_bias are returned exactly as passed in.
            return (
                w13_weight_swapped,
                w2_weight,
                w31_scales_interleaved,
                w2_scale_interleaved,
                w13_bias_swapped,
                w2_bias,
            )

    elif mxfp4_backend == Mxfp4MoeBackend.AITER:
        from vllm._aiter_ops import rocm_aiter_ops

        # Biases are optional for this backend; cast to float32 when present.
        if w13_bias is not None:
            w13_bias = w13_bias.data.to(torch.float32)
        if w2_bias is not None:
            w2_bias = w2_bias.data.to(torch.float32)

        e, n, k = w13_weight.shape

        # De-interleave w13 rows: gate/up pairs -> contiguous gate, up blocks
        # ``copy_`` writes the result back into the caller's tensor in place.
        w13_weight.view(torch.uint8).copy_(
            w13_weight.data.view(torch.uint8)
            .view(e, n // 2, 2, k)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(e, n, k)
        )
        # The scale tensor gets the same row de-interleave via ``.data``
        # rebinding, which also modifies the caller's tensor object.
        w13_weight_scale.data = (
            w13_weight_scale.data.view(e, n // 2, 2, -1)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(e, n, -1)
        )

        # View as native FP4 dtype for AITER shuffle
        w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2)
        w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)

        # Shuffle weights and scales for AITER CK kernel layout
        # Scales are flattened to 2-D (rows x last dim) before the shuffle.
        # The trailing boolean is True for w13 and False for w2.
        w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
        shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
            w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
            num_experts,
            True,
        )

        w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
        shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
            w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
            num_experts,
            False,
        )

        # Permute bias to match de-interleaved weight layout
        if w13_bias is not None:
            w13_bias = (
                w13_bias.data.view(-1, n // 2, 2)
                .permute(0, 2, 1)
                .contiguous()
                .view(-1, n)
            )

        return (
            w13_weight,
            w2_weight,
            shuffled_w13_scale,
            shuffled_w2_scale,
            w13_bias,
            w2_bias,
        )

    elif mxfp4_backend in TRITON_BACKENDS:
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        assert w13_bias is not None and w2_bias is not None
        w13_bias = w13_bias.to(torch.float32)
        w2_bias = w2_bias.to(torch.float32)

        # ``_swizzle_mxfp4`` returns the swizzled weight, a flex-data object
        # and the swizzled scale for each weight/scale pair.
        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            w13_weight,
            w13_weight_scale,
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            w2_weight,
            w2_weight_scale,
        )

        # For this backend the scale slots of the return tuple carry
        # PrecisionConfig objects wrapping the scale and flex data.
        w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
        )
        w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
        )

        # Remove the original weight attributes from the layer.
        # NOTE(review): presumably to release the un-swizzled copies — confirm.
        del layer.w13_weight
        del layer.w2_weight

        return (
            w13_weight,
            w2_weight,
            w13_precision_config,
            w2_precision_config,
            w13_bias,
            w2_bias,
        )
    elif mxfp4_backend == Mxfp4MoeBackend.XPU:
        # No additional transformation needed for XPU backend
        return (
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )
    elif mxfp4_backend == Mxfp4MoeBackend.EMULATION:
        # No additional transformation needed for emulation backend,
        # weights are dequantized on the fly in the experts class.
        return (
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )
    else:
        # No branch matched the requested backend.
        raise ValueError(
            f"Unsupported mxfp4_backend: {mxfp4_backend}: "
            f"should be one of: {list(Mxfp4MoeBackend)}."
        )

make_mxfp4_moe_kernel

make_mxfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[FusedMoEExperts],
    mxfp4_backend: Mxfp4MoeBackend,
    routing_tables: tuple[Tensor, Tensor, Tensor]
    | None = None,
    shared_experts: Module | None = None,
) -> FusedMoEKernel

Create a FusedMoEKernel for the given MXFP4 backend.

Source code in vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
def make_mxfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    mxfp4_backend: Mxfp4MoeBackend,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
    """Assemble the FusedMoEKernel for an MXFP4 backend.

    Builds the prepare/finalize stage, instantiates ``experts_cls`` with the
    arguments its activation format requires, and wraps both in a
    ``mk.FusedMoEKernel``.

    Args:
        moe_quant_config: Quantization config given to both stages.
        moe_config: MoE config given to both stages.
        experts_cls: Experts implementation to instantiate.
        mxfp4_backend: Selected backend; in-place execution is turned off
            for members of ``TRTLLM_BACKENDS``.
        routing_tables: Optional routing tables for prepare/finalize.
        shared_experts: Optional shared-experts module; it is forwarded only
            when the parallel config uses the batched activation format.

    Returns:
        The constructed kernel.
    """
    # Prepare/finalize stage.
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
    )
    assert prepare_finalize is not None

    logger.info_once("Using %s", prepare_finalize.__class__.__name__)

    # Experts stage: batched experts need the token cap and dispatcher count.
    uses_batched_experts = (
        prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts
    )
    if uses_batched_experts:
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts = experts_cls(
            moe_config=moe_config,
            quant_config=moe_quant_config,
            max_num_tokens=max_num_tokens,
            num_dispatchers=prepare_finalize.num_dispatchers(),
        )
    else:
        experts = experts_cls(moe_config=moe_config, quant_config=moe_quant_config)

    # Shared experts are attached only for the batched activation format.
    if moe_config.moe_parallel_config.use_batched_activation_format:
        kernel_shared_experts = shared_experts
    else:
        kernel_shared_experts = None

    # In-place is allowed unless disabled in the config or the backend is a
    # TRTLLM one.
    run_inplace = (
        not moe_config.disable_inplace and mxfp4_backend not in TRTLLM_BACKENDS
    )

    return mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        shared_experts=kernel_shared_experts,
        inplace=run_inplace,
    )

make_mxfp4_moe_quant_config

make_mxfp4_moe_quant_config(
    mxfp4_backend: Mxfp4MoeBackend,
    w1_scale: Union[Tensor, PrecisionConfig],
    w2_scale: Union[Tensor, PrecisionConfig],
    w1_bias: Tensor | None = None,
    w2_bias: Tensor | None = None,
) -> FusedMoEQuantConfig | None

Create a FusedMoEQuantConfig for the given MXFP4 backend.

Source code in vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
def make_mxfp4_moe_quant_config(
    mxfp4_backend: Mxfp4MoeBackend,
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig | None:
    """Build the FusedMoEQuantConfig that matches ``mxfp4_backend``.

    The backend picks one of three config factories:

    * the FlashInfer MXFP4+MXFP8 backends use
      ``mxfp4_mxfp8_moe_quant_config``;
    * Marlin, batched Marlin, Triton, unfused Triton, the FlashInfer BF16
      backends and AITER use ``mxfp4_w4a16_moe_quant_config``;
    * any other backend uses ``ocp_mx_moe_quant_config`` with
      ``quant_dtype="mxfp4"``.

    All factories receive the same scale and bias arguments.
    """
    mxfp8_activation_backends = (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )
    w4a16_backends = (
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.TRITON_UNFUSED,
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.AITER,
    )
    # Keyword arguments shared by every factory.
    scale_and_bias = {
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
    }

    if mxfp4_backend in mxfp8_activation_backends:
        return mxfp4_mxfp8_moe_quant_config(**scale_and_bias)
    if mxfp4_backend in w4a16_backends:
        return mxfp4_w4a16_moe_quant_config(**scale_and_bias)
    return ocp_mx_moe_quant_config(quant_dtype="mxfp4", **scale_and_bias)

map_mxfp4_backend

map_mxfp4_backend(
    runner_backend: MoEBackend,
) -> Mxfp4MoeBackend

Map user's moe_backend string to Mxfp4MoeBackend.

Source code in vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
    """Translate the user-facing ``moe_backend`` name into an Mxfp4MoeBackend.

    Raises:
        ValueError: If ``runner_backend`` has no MXFP4 counterpart; the
            message lists the accepted names.
    """
    name_to_backend: dict[str, Mxfp4MoeBackend] = {
        "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        "flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        "flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
        "triton": Mxfp4MoeBackend.TRITON,
        "marlin": Mxfp4MoeBackend.MARLIN,
        "aiter": Mxfp4MoeBackend.AITER,
        "xpu": Mxfp4MoeBackend.XPU,
        "emulation": Mxfp4MoeBackend.EMULATION,
    }
    selected = name_to_backend.get(runner_backend)
    # Guard clause: a missing (or falsy) lookup result is an unsupported name.
    if not selected:
        raise ValueError(
            f"moe_backend='{runner_backend}' is not supported for MXFP4 MoE. "
            f"Expected one of {list(name_to_backend.keys())}."
        )
    return selected

mxfp4_round_up_hidden_size_and_intermediate_size

mxfp4_round_up_hidden_size_and_intermediate_size(
    backend: Mxfp4MoeBackend,
    hidden_size: int,
    intermediate_size: int,
) -> tuple[int, int]

Round up hidden_size and intermediate_size based on backend requirements.

Source code in vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
def mxfp4_round_up_hidden_size_and_intermediate_size(
    backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
) -> tuple[int, int]:
    """Pad both sizes up to the alignment the backend needs.

    Alignments, as ``(hidden, intermediate)``:

    * Marlin / batched Marlin: ``(128 on XPU, else 256, 128)``
    * TRTLLM backends: ``(256, 256)``
    * FlashInfer CUTLASS backends: ``(128, 128)``
    * any other backend on ROCm: ``(256, 256)``
    * everything else: hidden size unchanged, intermediate aligned to 64

    Returns:
        ``(hidden_size, intermediate_size)`` after rounding.
    """
    marlin_backends = (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN)
    cutlass_backends = (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )

    # Choose the alignments first; the rounding itself happens once below.
    if backend in marlin_backends:
        intermediate_align = 128
        hidden_align = 128 if current_platform.is_xpu() else 256
    elif backend in TRTLLM_BACKENDS:
        intermediate_align = hidden_align = 256
    elif backend in cutlass_backends:
        intermediate_align = hidden_align = 128
    elif current_platform.is_rocm():
        intermediate_align = hidden_align = 256
    else:
        # Default case: only the intermediate size is padded.
        return hidden_size, round_up(intermediate_size, 64)

    intermediate_size = round_up(intermediate_size, intermediate_align)
    hidden_size = round_up(hidden_size, hidden_align)
    return hidden_size, intermediate_size

select_gpt_oss_mxfp4_moe_backend

select_gpt_oss_mxfp4_moe_backend(
    config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[FusedMoEExperts] | None]

Select the primary MXFP4 MoE backend. Note: Shape-specific fallbacks may still occur at runtime.

Source code in vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
def select_gpt_oss_mxfp4_moe_backend(
    config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
    """
    Select the primary MXFP4 MoE backend.
    Note: Shape-specific fallbacks may still occur at runtime.

    Selection order:

    1. LoRA enabled: CUDA only; unfused Triton when
       ``VLLM_MXFP4_USE_MARLIN`` is ``False`` and triton kernels are
       supported, otherwise Marlin.
    2. ``config.moe_backend`` other than ``"auto"``: that backend is
       validated and returned, or a ``ValueError`` is raised.
    3. Explicit environment variables, checked in this order: FlashInfer
       BF16, FlashInfer MXFP8 (TRTLLM), FlashInfer MXFP8 (CUTLASS), Marlin.
    4. The first backend from ``_get_priority_backends()`` that has a
       kernel class supporting the configuration.
    5. Platform fallbacks: XPU, then an error on CUDA/ROCm, otherwise
       ``(Mxfp4MoeBackend.NONE, None)``.

    Args:
        config: MoE configuration of the deployment.

    Returns:
        The chosen backend and its experts class. The class is ``None``
        only when the backend is ``Mxfp4MoeBackend.NONE``.

    Raises:
        NotImplementedError: LoRA on a non-CUDA platform, or no supporting
            backend on CUDA/ROCm.
        ValueError: An explicitly requested backend does not support the
            configuration, the ``moe_backend`` name is unknown, or
            ``VLLM_USE_FLASHINFER_MOE_MXFP4_BF16`` is enabled on a device
            that is neither SM90 nor in the SM100 family.
    """
    # Triton kernels need the package to be available and a device
    # capability in the half-open range [(9, 0), (11, 0)).
    device_capability = current_platform.get_device_capability()
    triton_kernels_supported = (
        has_triton_kernels()
        and device_capability is not None
        and (9, 0) <= device_capability < (11, 0)
    )

    # LoRA: separate experts backend path
    if config.is_lora_enabled:
        if not current_platform.is_cuda():
            # ROCm: Triton mxfp4 LoRA hits GPU memory faults due to
            # triton_kernels.tensor.Tensor / HIP read-only page issues
            # during weight swizzle and LoRA forward. Needs work from
            # the triton_kernels/aiter side.
            raise NotImplementedError("Mxfp4 LoRA is currently only supported on CUDA.")
        # ``is False`` is an identity check, so only a literal False picks
        # the Triton path here.
        if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
            logger.info_once("Using Triton backend for mxfp4 lora")
            # The first kernel class registered for the backend is used.
            return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls(
                Mxfp4MoeBackend.TRITON_UNFUSED
            )[0]
        logger.info_once("Using Marlin backend for mxfp4 lora")
        return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0]

    # The activation format follows the parallel config.
    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if config.moe_parallel_config.use_batched_activation_format
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: Mxfp4MoeBackend):
        """Format the message logged when ``backend`` is selected."""
        return f"Using '{backend.value}' Mxfp4 MoE backend."

    def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
        """Format the "backend unsupported" message, with ``reason`` if given."""
        if reason:
            return (
                f"Mxfp4 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        return (
            f"Mxfp4 MoE backend '{backend.value}' does not support the "
            "deployment configuration."
        )

    def _return_or_raise(
        backend: Mxfp4MoeBackend,
        config: FusedMoEConfig,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
        activation_format: mk.FusedMoEActivationFormat,
    ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
        """Return ``backend`` with its first supporting kernel class.

        Raises ValueError when no kernel class of ``backend`` supports the
        configuration. The message carries the reason reported by the last
        class tried, or no reason if there were no classes.
        """
        reason: str | None = None
        for k_cls in backend_to_kernel_cls(backend):
            # NOTE(review): ``k_cls`` is passed explicitly as the first
            # argument; presumably ``is_supported_config`` is a static
            # method that expects the class — confirm.
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend))
                return backend, k_cls
        raise ValueError(_make_log_unsupported(backend, reason))

    # An explicitly requested backend is never silently replaced: it is
    # either returned or a ValueError is raised.
    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_mxfp4_backend(runner_backend)
        # A Marlin request is switched to the batched Marlin backend when
        # the batched activation format is in use.
        if (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
            and requested_backend == Mxfp4MoeBackend.MARLIN
        ):
            requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN
        return _return_or_raise(
            requested_backend,
            config,
            kMxfp4Static,
            _backend_activation_key(requested_backend),
            activation_format,
        )

    # Select kernels in order of backend.
    # The list is created fresh by _get_priority_backends(), so removing
    # entries below only affects this call.
    AVAILABLE_BACKENDS = _get_priority_backends()

    # Handle explicit FlashInfer MXFP4 BF16 configuration.
    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
        if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
            # Explicitly disabled: drop both FlashInfer BF16 candidates.
            AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16)
            AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16)
        else:
            # Explicitly enabled: capability 90 uses the CUTLASS backend,
            # the capability-100 family uses TRTLLM, anything else fails.
            if current_platform.is_device_capability(90):
                return _return_or_raise(
                    Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
                    config,
                    kMxfp4Static,
                    None,
                    activation_format,
                )
            if current_platform.is_device_capability_family(100):
                return _return_or_raise(
                    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
                    config,
                    kMxfp4Static,
                    None,
                    activation_format,
                )
            raise ValueError(
                "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 is set but the "
                "current device capability is not supported. "
                "Only SM90 (CUTLASS) and SM100+ (TRTLLM) are supported."
            )

    # Handle explicit FlashInfer MXFP4 MXFP8 TRTLLM configuration.
    if (
        envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")
        and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
    ):
        return _return_or_raise(
            Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
            config,
            kMxfp4Static,
            kMxfp8Dynamic,
            activation_format,
        )

    # Handle explicit FlashInfer MXFP4 MXFP8 CUTLASS configuration.
    if (
        envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS")
        and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
    ):
        return _return_or_raise(
            Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
            config,
            kMxfp4Static,
            kMxfp8Dynamic,
            activation_format,
        )

    # Handle explicit Marlin MXFP4 configuration.
    # NOTE(review): unlike the explicit ``moe_backend`` path above, this
    # branch does not switch to BATCHED_MARLIN for the batched activation
    # format — confirm that is intended.
    if envs.is_set("VLLM_MXFP4_USE_MARLIN") and envs.VLLM_MXFP4_USE_MARLIN:
        return _return_or_raise(
            Mxfp4MoeBackend.MARLIN,
            config,
            kMxfp4Static,
            None,
            activation_format,
        )

    # Automatic selection: the first (backend, kernel class) pair that
    # supports the configuration wins; rejections are logged at debug level.
    for backend in AVAILABLE_BACKENDS:
        activation_key = _backend_activation_key(backend)
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, kMxfp4Static, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend))
                return backend, k_cls
            else:
                logger.debug_once(_make_log_unsupported(backend, reason))

    # Nothing in the priority list matched; fall back by platform.
    if current_platform.is_xpu():
        backend = Mxfp4MoeBackend.XPU
        # NOTE(review): _return_or_raise logs this same message again when
        # it succeeds.
        logger.info_once(_make_log_backend(backend))
        return _return_or_raise(
            Mxfp4MoeBackend.XPU,
            config,
            kMxfp4Static,
            None,
            activation_format,
        )

    if current_platform.is_cuda() or current_platform.is_rocm():
        raise NotImplementedError(
            "No MXFP4 MoE backend supports the deployment configuration."
        )

    # Other platforms: report that no backend was selected.
    return Mxfp4MoeBackend.NONE, None