fix: remove broken hc_head warmup (wrong tensor shape)
`hc_head_fuse_tilelang` expects `fn.shape[0] == hc_mult` (4), but the warmup passed an `fn` tensor with `hc_mult * (2 + hc_mult)` rows (24). Since `--enforce-eager` disables `@torch.compile` anyway, `hc_head` runs eagerly and does not need a warmup.
This commit is contained in:
@@ -2276,29 +2276,8 @@ class DeepseekV4ForCausalLM(nn.Module):
|
||||
except Exception as e:
|
||||
print(f" mhc_post warmup failed (non-fatal): {e}", flush=True)
|
||||
|
||||
# Warmup hc_head (also @torch.compile lazy)
|
||||
try:
|
||||
from vllm.platforms import current_platform
|
||||
hc_head_fn = torch.randn(
|
||||
hc_mult * (2 + hc_mult), hc_mult * hidden_size,
|
||||
dtype=torch.float32, device=device)
|
||||
hc_head_scale = torch.randn(3, dtype=torch.float32, device=device)
|
||||
hc_head_base = torch.randn(
|
||||
hc_mult * (2 + hc_mult), dtype=torch.float32, device=device)
|
||||
hs = torch.randn(1, hc_mult, hidden_size, dtype=torch.bfloat16,
|
||||
device=device)
|
||||
hc_head(hs, hc_head_fn, hc_head_scale, hc_head_base,
|
||||
config.rms_norm_eps, config.hc_eps)
|
||||
print(" hc_head ✓", flush=True)
|
||||
except Exception as e:
|
||||
print(f" hc_head warmup failed (non-fatal): {e}", flush=True)
|
||||
|
||||
# Free dummy tensors
|
||||
del residual, fn, hc_scale, hc_base, x, post_mix, comb_mix
|
||||
try:
|
||||
del hc_head_fn, hc_head_scale, hc_head_base, hs
|
||||
except NameError:
|
||||
pass
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
Reference in New Issue
Block a user