vllm.compilation.passes.fusion.rocm_aiter_fusion ¶
AddAiterRMSNormPadPattern ¶
This pattern replaces an aiter_rmsnorm_with_add & a pad op with a custom triton_add_rmsnorm_pad op from AITER.
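The unfused computation this pattern targets can be sketched in plain Python. This is a reference for the semantics only, not the AITER kernel: the function name, the `pad_to` parameter, and the choice of zero padding on the hidden dimension are assumptions made for illustration.

```python
import math


def add_rmsnorm_pad(x, residual, weight, eps=1e-6, pad_to=None):
    """Residual add, RMSNorm, then right-pad with zeros (illustrative only)."""
    # Residual add; the sum is also returned as the updated residual.
    summed = [a + b for a, b in zip(x, residual)]
    # RMSNorm over the hidden dimension.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in summed) / len(summed) + eps)
    normed = [v * inv_rms * w for v, w in zip(summed, weight)]
    # Assumption: padding appends zeros until the row has pad_to elements.
    if pad_to is not None and pad_to > len(normed):
        normed = normed + [0.0] * (pad_to - len(normed))
    return normed, summed


out, new_residual = add_rmsnorm_pad(
    [1.0, 2.0], [2.0, 2.0], [1.0, 1.0], eps=0.0, pad_to=4
)
```

A fused op performs these three steps in one kernel launch, so the intermediate sum and the unpadded normalized row never need a separate round trip through memory.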
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterFusedAddRMSFp8GroupQuantPattern ¶
Bases: AiterRMSNormQuantPattern
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops into an aiter rms_norm_with_add_group_fp8_quant op.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterFusedAddRMSNormDynamicQuantPattern ¶
Bases: AiterRMSNormQuantPattern
AITER RMSNorm Fused Add + Dynamic Quantization pattern.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterRMSFp8GroupQuantPattern ¶
Bases: AiterRMSNormQuantPattern
This pattern fuses aiter rms_norm & group fp8 quant custom ops into an aiter rms_norm_group_fp8_quant op.
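As a rough illustration of the two steps being fused, the sketch below applies RMSNorm and then a per-group dynamic scale in plain Python. The helper names, the group layout (contiguous groups along the hidden dimension), and the `FP8_MAX` value of 448 (the `float8_e4m3fn` maximum; other fp8 variants use a different range) are assumptions, and rounding to an actual fp8 type is omitted.

```python
import math

FP8_MAX = 448.0  # assumption: float8_e4m3fn range; other fp8 variants differ


def rms_norm(x, weight, eps=1e-6):
    """Plain RMSNorm over one row."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def group_quant(x, group_size, fp8_max=FP8_MAX):
    """Per-group dynamic scaling: one scale per contiguous group of elements."""
    scaled, scales = [], []
    for start in range(0, len(x), group_size):
        group = x[start:start + group_size]
        amax = max(abs(v) for v in group)
        scale = amax / fp8_max if amax > 0 else 1.0
        scales.append(scale)
        # Values are only scaled and clamped here; rounding to fp8 is omitted.
        scaled.extend(max(-fp8_max, min(fp8_max, v / scale)) for v in group)
    return scaled, scales


def rms_norm_group_quant(x, weight, group_size, eps=1e-6):
    """Unfused reference: normalize first, then quantize group by group."""
    return group_quant(rms_norm(x, weight, eps), group_size)
```

For example, `group_quant([1.0, -2.0, 4.0, 0.5], group_size=2)` yields one scale per pair of elements, so each group uses the full representable range independently.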
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterRMSNormDynamicQuantPattern ¶
Bases: AiterRMSNormQuantPattern
AITER RMSNorm + Dynamic Quantization pattern.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterSiluMulFp8GroupQuantPattern ¶
Bases: VllmPatternReplacement
This pattern fuses aiter silu_and_mul & group fp8 quant custom ops into an aiter silu_and_mul_group_fp8_quant op.
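For orientation, the activation half of this fusion can be written as a small plain-Python reference. It assumes the usual gated layout in which the first half of the last dimension is the gate and the second half is the multiplier; the function below is illustrative and is not the AITER op itself.

```python
import math


def silu_and_mul(x):
    """Apply SiLU to the first half of x and multiply by the second half."""
    d = len(x) // 2
    gate, up = x[:d], x[d:]
    # SiLU(g) = g * sigmoid(g) = g / (1 + exp(-g))
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]


activated = silu_and_mul([0.0, 1.0, 3.0, 4.0])
```

In the fused op, the group fp8 quantization is applied directly to this result instead of being run as a second op on a materialized intermediate tensor.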
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
MLADualRMSNormFusionPass ¶
Bases: VllmFusionPatternMatcherPass
Post-grad PatternMatcher pass that fuses paired q / kv RMS norms in MLA attention into fused_mla_dual_rms_norm backed by aiter's fused_qk_rmsnorm HIP kernel.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
MLADualRMSNormPattern ¶
Bases: VllmPatternReplacement[..., tuple[Tensor, Tensor, Tensor]]
Fuse paired q_a_layernorm + kv_a_layernorm in MLA attention into AITER's fused_qk_rmsnorm HIP kernel.
Target FX-graph pattern (unfused, vllm_ir stage)::
    gemm -> split_with_sizes([q_dim, kv_dim])
      +-- q_c -> vllm_ir.rms_norm(q_c, q_w, eps)
      +-- kv_lora -> split_with_sizes([kv_c_dim, k_pe_dim])
            +-- kv_c -> vllm_ir.rms_norm(kv_c, kv_w, eps)
            +-- k_pe
The pattern covers the connected subgraph rooted at the first split_with_sizes (which produces q_c and kv_lora), through the two rms_norm calls, and the k_pe passthrough.
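The dataflow of the unfused subgraph described above can be sketched on a single row in plain Python. The function names and the flat-list representation are hypothetical and chosen only to mirror the two splits and two norms; the real pattern operates on FX graph nodes and tensors.

```python
import math


def rms_norm(x, weight, eps):
    """Plain RMSNorm over one row."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def mla_dual_rms_norm_reference(gemm_out, q_dim, kv_c_dim, k_pe_dim,
                                q_w, kv_w, eps=1e-6):
    """Unfused reference: two splits, two RMS norms, k_pe passed through."""
    # First split: q_c and kv_lora.
    q_c = gemm_out[:q_dim]
    kv_lora = gemm_out[q_dim:q_dim + kv_c_dim + k_pe_dim]
    # Second split: kv_c and k_pe.
    kv_c, k_pe = kv_lora[:kv_c_dim], kv_lora[kv_c_dim:]
    # Two independent norms; k_pe is returned unchanged.
    return rms_norm(q_c, q_w, eps), rms_norm(kv_c, kv_w, eps), k_pe


q_n, kv_n, k_pe = mla_dual_rms_norm_reference(
    [3.0, 4.0, 1.0, 1.0, 7.0], q_dim=2, kv_c_dim=2, k_pe_dim=1,
    q_w=[1.0, 1.0], kv_w=[1.0, 1.0], eps=0.0,
)
```

The three return values correspond to the three-tensor output in the class signature: normalized q, normalized kv_c, and the untouched k_pe.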
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
RocmAiterRMSNormQuantFusionPass ¶
Bases: VllmPatternMatcherPass
This pass fuses aiter rms_norm & vllm/aiter quant custom ops into a fused rms_norm_quant op. It also supports fused_add_rms_norm.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
RocmAiterSiluMulFp8GroupQuantFusionPass ¶
Bases: VllmFusionPatternMatcherPass
This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton. This will be addressed in a future version of PyTorch: https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
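The register-once constraint can be illustrated with a generic singleton guard. This is a minimal sketch of the idea under assumed names (`OneTimeRegistry`, `register_once`); it is not how vLLM or the PyTorch pattern matcher implements it.

```python
class OneTimeRegistry:
    """Hypothetical singleton that accepts pattern registration only once."""

    _instance = None

    def __new__(cls):
        # Every construction returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.patterns = []
            cls._instance._registered = False
        return cls._instance

    def register_once(self, patterns):
        """Register patterns on the first call; later calls are no-ops."""
        if self._registered:
            return False
        self.patterns.extend(patterns)
        self._registered = True
        return True
```

Constructing the class twice returns the same object, so a second attempt to register the same patterns is skipped rather than duplicated.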
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
RocmAiterTritonAddRMSNormPadFusionPass ¶
Bases: VllmPatternMatcherPass
This pass replaces an AITER CK RMSNorm + residual add and a pad op with a triton_add_rmsnorm_pad op from AITER.