[Bug] Fix fp8 deepgemm batch invariant (#37718)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-22 08:57:20 -04:00
committed by GitHub
parent b3e846017d
commit 77d24c4bfe

View File

@@ -305,6 +305,11 @@ def _flashinfer_fp8_blockscale_gemm_impl(
) )
return output return output
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
if vllm_is_batch_invariant():
return run_deepgemm(input, weight, weight_scale)
condition = input.shape[0] < 32 condition = input.shape[0] < 32
# PyTorch's torch.compile cannot handle input-dependent control flow in standard # PyTorch's torch.compile cannot handle input-dependent control flow in standard