Commit 3978fe2

[NPU] Add optimized NPU mhc (#1173)
Add Ascend NPU Triton kernels for the three mHC sub-operators:
- Fused matmul + RMS normalization (forward/backward)
- Sinkhorn routing with split pre/post/residual coefficients (forward/backward)
- Pre-aggregate weighted sum (forward/backward)
- Post + residual mixing (forward/backward)

NPU optimizations applied:
- Unified UB tiling via compute_default_tiling_strategy for matrix
- Persistent grid-stride loops (tl.range + num_programs)
- Adaptive BLOCK_N/BLOCK_M for core utilisation at small seq_len
- Fused backward coefficient assembly kernel

Hardware Type: Atlas 800I A2
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
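
For readers unfamiliar with the "persistent grid-stride loops" item in the commit message, the following is a minimal pure-Python sketch of the scheduling idea only. It is not the kernel code from this commit: `cdiv`, `grid_stride_tiles`, and the example sizes are hypothetical names and values chosen for illustration. The assumption is that a fixed number of programs is launched and each one walks over tiles with a stride equal to the grid size, which is the pattern a Triton loop over `tl.range(pid, num_tiles, num_programs)` would express.

```python
def cdiv(a, b):
    """Ceiling division, used to count how many tiles cover a dimension."""
    return (a + b - 1) // b


def grid_stride_tiles(num_tiles, num_programs):
    """Return, for each program id, the tile indices it would process.

    Hypothetical helper: models a persistent kernel in which each program
    starts at its own id and advances by the total number of programs.
    """
    assignment = []
    for pid in range(num_programs):
        # Same index pattern as a loop over range(pid, num_tiles, num_programs).
        assignment.append(list(range(pid, num_tiles, num_programs)))
    return assignment


# Illustrative sizes only (not taken from the commit).
n_rows = 26
block_m = 4
num_tiles = cdiv(n_rows, block_m)  # 7 tiles of up to 4 rows each
plan = grid_stride_tiles(num_tiles, num_programs=3)

for pid, tiles in enumerate(plan):
    print(f"program {pid} handles tiles {tiles}")
```

With this scheme every tile is visited exactly once regardless of how the tile count compares to the number of programs, so the launch grid can stay fixed at the number of available cores while the problem size varies.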
1 parent 2ca3bd0 commit 3978fe2

2 files changed, +1686 -0 lines changed


src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -55,6 +55,9 @@
 from liger_kernel.ops.backends._ascend.ops.llama4_rope import LigerLlama4RopeFunction
 from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_backward
 from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_forward
+from liger_kernel.ops.backends._ascend.ops.mhc import LigerMHCCoeffsFunction
+from liger_kernel.ops.backends._ascend.ops.mhc import LigerMHCPostResFunction
+from liger_kernel.ops.backends._ascend.ops.mhc import LigerMHCPreFunction
 from liger_kernel.ops.backends._ascend.ops.poly_norm import LigerPolyNormFunction
 from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_backward
 from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_forward
@@ -146,4 +149,7 @@
     "LigerFusedLinearCrossEntropyFunction",
     "fused_linear_cross_entropy_forward",
     "fused_linear_cross_entropy_backward",
+    "LigerMHCCoeffsFunction",
+    "LigerMHCPreFunction",
+    "LigerMHCPostResFunction",
 ]
