feat: tqdm progress bar for expert weight loading

Replaces heartbeat prints with a clean tqdm bar:
  Loading Native NVFP4 Expert Weights: 50%|██████████░░| 480/960
This commit is contained in:
2026-05-16 06:09:22 +00:00
parent 2e4ff6b8d4
commit 5d975d00d9

View File

@@ -223,6 +223,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
"""
_cutedsl_runner: 'CuTeDSLMoERunner | None' = None
_weight_load_count: int = 0
+_weight_load_tqdm: 'tqdm | None' = None
# NVFP4 E2M1 lookup table (positive values, sign from bit 3)
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
@@ -352,11 +353,15 @@ class DeepseekV4MegaMoEExperts(nn.Module):
shard_id: str,
expert_id: int,
) -> bool:
-# Heartbeat: print every 256 weight loads so k8s/docker
-# don't think the pod is dead during GPU upload
+# Progress bar for k8s/docker liveness during GPU upload
+if DeepseekV4MegaMoEExperts._weight_load_count == 0:
+DeepseekV4MegaMoEExperts._weight_load_tqdm = tqdm(
+total=self.num_local_experts * 20, # ~20 tensors per expert
+desc=" Loading Native NVFP4 Expert Weights",
+unit="tensor",
+)
DeepseekV4MegaMoEExperts._weight_load_count += 1
-if DeepseekV4MegaMoEExperts._weight_load_count % 256 == 1:
-print(f" Loading expert weights... ({DeepseekV4MegaMoEExperts._weight_load_count})", flush=True)
+DeepseekV4MegaMoEExperts._weight_load_tqdm.update(1)
local_expert_id = self._map_global_expert_id(expert_id)
if local_expert_id == -1: