[Bug] Fix fp8 deepgemm batch invariant (#37718)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -305,6 +305,11 @@ def _flashinfer_fp8_blockscale_gemm_impl(
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||||
|
|
||||||
|
if vllm_is_batch_invariant():
|
||||||
|
return run_deepgemm(input, weight, weight_scale)
|
||||||
|
|
||||||
condition = input.shape[0] < 32
|
condition = input.shape[0] < 32
|
||||||
|
|
||||||
# PyTorch's torch.compile cannot handle input-dependent control flow in standard
|
# PyTorch's torch.compile cannot handle input-dependent control flow in standard
|
||||||
|
|||||||
Reference in New Issue
Block a user