stuff: warm up kernel compilation with a single expert instead of num_experts
This commit is contained in:
@@ -402,13 +402,13 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
|
||||
_warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
|
||||
mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16)
|
||||
del _warmup_a_bf16
|
||||
- _warmup_b_bf16 = torch.randn(num_experts, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
|
||||
+ _warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 # 1 expert: kernel compiles same regardless of count
|
||||
mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16)
|
||||
del _warmup_b_bf16
|
||||
out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
|
||||
- expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device)
|
||||
- global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device)
|
||||
- global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device)
|
||||
+ expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=device)
|
||||
+ global_scale_a = torch.ones(1, dtype=torch.float32, device=device)
|
||||
+ global_scale_b = torch.ones(1, dtype=torch.float32, device=device)
|
||||
|
||||
kernel = ScaledGroupedGemmKernel(
|
||||
scenario="2Dx3D",
|
||||
@@ -614,15 +614,15 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
|
||||
_warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
|
||||
mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16)
|
||||
del _warmup_a_bf16
|
||||
- _warmup_b_bf16 = torch.randn(num_experts, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
|
||||
+ _warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1 # 1 expert: kernel compiles same regardless of count
|
||||
mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16)
|
||||
del _warmup_b_bf16
|
||||
# BF16 output (Stage 1: we still write BF16)
|
||||
# The fused kernel writes intermediate (N/2) since gate+up → silu result
|
||||
out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
|
||||
- expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device)
|
||||
- global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device)
|
||||
- global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device)
|
||||
+ expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=device)
|
||||
+ global_scale_a = torch.ones(1, dtype=torch.float32, device=device)
|
||||
+ global_scale_b = torch.ones(1, dtype=torch.float32, device=device)
|
||||
|
||||
kernel = FusedSwiGLUScaledGroupedGemmKernel(
|
||||
scenario="2Dx3D",
|
||||
|
||||
Reference in New Issue
Block a user