fix: remove broken hc_head warmup (wrong tensor shape)

hc_head_fuse_tilelang expects fn shape[0]=hc_mult (4) but we passed
hc_mult*(2+hc_mult) (24). Since --enforce-eager disables @torch.compile
anyway, hc_head runs eagerly and doesn't need warmup.
This commit is contained in:
2026-05-16 10:11:34 +00:00
parent c803180706
commit f0c1be3ced

View File

@@ -2276,29 +2276,8 @@ class DeepseekV4ForCausalLM(nn.Module):
except Exception as e:
print(f" mhc_post warmup failed (non-fatal): {e}", flush=True)
# Warmup hc_head (also @torch.compile lazy)
try:
from vllm.platforms import current_platform
hc_head_fn = torch.randn(
hc_mult * (2 + hc_mult), hc_mult * hidden_size,
dtype=torch.float32, device=device)
hc_head_scale = torch.randn(3, dtype=torch.float32, device=device)
hc_head_base = torch.randn(
hc_mult * (2 + hc_mult), dtype=torch.float32, device=device)
hs = torch.randn(1, hc_mult, hidden_size, dtype=torch.bfloat16,
device=device)
hc_head(hs, hc_head_fn, hc_head_scale, hc_head_base,
config.rms_norm_eps, config.hc_eps)
print(" hc_head ✓", flush=True)
except Exception as e:
print(f" hc_head warmup failed (non-fatal): {e}", flush=True)
# Free dummy tensors
del residual, fn, hc_scale, hc_base, x, post_mix, comb_mix
try:
del hc_head_fn, hc_head_scale, hc_head_base, hs
except NameError:
pass
torch.cuda.empty_cache()
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: