[Kernels] Modular kernel refactor (#24812)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
|
||||
register_prepare_and_finalize(
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
standard_format,
|
||||
-nvfp4_types,
|
||||
+nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
backend=None,
|
||||
force_multigpu=True,
|
||||
@@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
|
||||
register_experts(
|
||||
FlashInferExperts,
|
||||
standard_format,
|
||||
-nvfp4_types,
|
||||
+nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
# Note: this is a hack to get it to run for now
|
||||
@@ -274,17 +274,15 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
-(
|
||||
-register_experts(
|
||||
-DeepGemmExperts,
|
||||
-standard_format,
|
||||
-fp8_types,
|
||||
-blocked_quantization_support=True,
|
||||
-supports_chunking=True,
|
||||
-supports_expert_map=True,
|
||||
-needs_matching_quant=False,
|
||||
-needs_deep_gemm=True,
|
||||
-),
|
||||
+register_experts(
|
||||
+DeepGemmExperts,
|
||||
+standard_format,
|
||||
+fp8_types,
|
||||
+blocked_quantization_support=True,
|
||||
+supports_chunking=True,
|
||||
+supports_expert_map=True,
|
||||
+needs_matching_quant=False,
|
||||
+needs_deep_gemm=True,
|
||||
)
|
||||
register_experts(
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
@@ -464,7 +462,7 @@ def make_fused_experts(
|
||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == DeepGemmExperts:
|
||||
-print("Making DeepGemmExperts {quant_config} ...")
|
||||
+print(f"Making DeepGemmExperts {quant_config} ...")
|
||||
experts = DeepGemmExperts(quant_config)
|
||||
elif fused_experts_type == TritonExperts:
|
||||
kwargs = quant_kwargs
|
||||
|
||||
Reference in New Issue
Block a user