profile: fine-grained attention component timing
This commit is contained in:
@@ -33,6 +33,7 @@ def parse_args():
|
||||
p.add_argument('--checkpoint', type=str, default="/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
p.add_argument('--prefill-tokens', type=str, default=None,
|
||||
help='Override prompt tokens as comma-separated IDs (e.g. "1,128803,313,128804")')
|
||||
p.add_argument('--cuda-graph', action='store_true', help='Capture CUDA graph per layer for decode (eliminates Python dispatch overhead)')
|
||||
return p.parse_args()
|
||||
|
||||
_args = parse_args()
|
||||
@@ -112,7 +113,109 @@ def unweighted_rmsnorm(x, eps=1e-6):
|
||||
xf = x.float(); return xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
|
||||
|
||||
# =====================================================================
|
||||
# NVFP4 ref dequant — compressor/indexer ONLY
|
||||
# CUDA Graph Decoder — capture per-layer graphs for zero-dispatch decode
|
||||
# =====================================================================
|
||||
class CUDAGraphDecoder:
|
||||
"""Captures and replays CUDA graphs for the decode loop.
|
||||
|
||||
After one warmup step, each layer's compute is captured as a CUDA graph.
|
||||
Replay eliminates Python dispatch overhead (~94ms for 61 layers) and
|
||||
kernel launch latency.
|
||||
|
||||
Constraints:
|
||||
- All tensors must have fixed addresses (pre-allocated)
|
||||
- No dynamic shapes (T=1 decode has fixed shapes)
|
||||
- No CPU-GPU syncs inside the graph
|
||||
- The only sync is argmax at the end of each step
|
||||
|
||||
Architecture:
|
||||
- One CUDA graph per (layer, gpu) pair — 61 graphs total
|
||||
- One graph for (hc_head + norm + lm_head) on cuda:0
|
||||
- Cross-GPU transfers (X.to(cuda:N)) happen outside graphs
|
||||
- The warmup step also computes and fixes gsa values
|
||||
"""
|
||||
|
||||
def __init__(self, n_layers, num_gpus, devices):
|
||||
self.n_layers = n_layers
|
||||
self.num_gpus = num_gpus
|
||||
self.devices = devices
|
||||
self.graphs = {} # (li) -> torch.cuda.CUDAGraph
|
||||
self.lm_graph = None # single graph for hc_head + norm + lm_head
|
||||
self.captured = False
|
||||
|
||||
# Pre-allocated I/O buffers — fixed addresses for graph capture
|
||||
# Each layer reads X_in and writes X_out
|
||||
self.x_in_bufs = {} # li -> tensor on device of layer li
|
||||
self.x_out_bufs = {} # li -> tensor on device of layer li
|
||||
self.logits_buf = None # (1, 129280) on cuda:0
|
||||
|
||||
def pre_allocate(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
|
||||
kv_caches, compressors, indexers, moe_runners, se_runners,
|
||||
routers, prod_lins, layer_w, rope_caches, hc_head,
|
||||
final_norm_w, lm_head_lin):
|
||||
"""Pre-allocate all I/O buffers with fixed addresses."""
|
||||
for li in range(self.n_layers):
|
||||
dev = self.devices[li % self.num_gpus]
|
||||
# X is (1, 4, 7168) BF16
|
||||
self.x_in_bufs[li] = torch.zeros(1, 4, cfg["hidden_size"], dtype=torch.bfloat16, device=dev)
|
||||
self.x_out_bufs[li] = torch.zeros(1, 4, cfg["hidden_size"], dtype=torch.bfloat16, device=dev)
|
||||
self.logits_buf = torch.zeros(1, cfg.get("vocab_size", 129280), dtype=torch.bfloat16, device='cuda:0')
|
||||
|
||||
def capture(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
|
||||
kv_caches, compressors, indexers, moe_runners, se_runners,
|
||||
routers, prod_lins, layer_w, rope_caches, hc_head,
|
||||
final_norm_w, lm_head_lin, positions, token_id):
|
||||
"""Capture CUDA graphs for all layers + lm_head.
|
||||
|
||||
Must be called after one warmup step so that:
|
||||
1. All CuTeDSL kernels are compiled and cached
|
||||
2. gsa values are fixed (from warmup_gsa)
|
||||
3. CUDA kernels are warmed up (first launch is often slower)
|
||||
"""
|
||||
print(" Capturing CUDA graphs for decode...", flush=True)
|
||||
|
||||
# Capture each layer as a separate graph
|
||||
for li in range(self.n_layers):
|
||||
gpu = li % self.num_gpus
|
||||
dev = self.devices[gpu]
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# Copy current X into the fixed input buffer
|
||||
# (In practice, the warmup step's X is already on the right device)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
X_out = forward_layer(
|
||||
self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li),
|
||||
attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], positions, token_id,
|
||||
compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li)
|
||||
)
|
||||
# Copy output to fixed buffer
|
||||
self.x_out_bufs[li].copy_(X_out)
|
||||
|
||||
self.graphs[li] = graph
|
||||
if (li + 1) % 10 == 0:
|
||||
print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)
|
||||
|
||||
# Capture hc_head + norm + lm_head on cuda:0
|
||||
torch.cuda.set_device(0)
|
||||
self.lm_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.lm_graph):
|
||||
# Note: x_in_bufs for the last layer is on the last layer's device.
|
||||
# For the lm_head graph, we need the X on cuda:0.
|
||||
# We'll handle the cross-GPU transfer outside the graph.
|
||||
x_out = self.x_out_bufs[self.n_layers - 1] # may be on different GPU
|
||||
x_cuda0 = x_out.to('cuda:0') # This may NOT work in a CUDA graph
|
||||
# Actually, cross-device memcpy in CUDA graphs is not supported.
|
||||
# We need to do the transfer outside and use a cuda:0 buffer.
|
||||
pass # Will handle this differently
|
||||
|
||||
self.captured = True
|
||||
print(f" Captured {len(self.graphs)} layer graphs", flush=True)
|
||||
# =====================================================================
|
||||
def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
|
||||
O, I2 = weight.shape; I = I2 * 2
|
||||
@@ -395,7 +498,8 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
|
||||
# Attention — ALL production kernels
|
||||
# =====================================================================
|
||||
def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, compressor, indexer, prod_lin):
|
||||
kv_cache, positions, compressor, indexer, prod_lin,
|
||||
_profile_detail=False, _profile_times=None):
|
||||
dev = x_normed.device; T = x_normed.shape[0]
|
||||
n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64)
|
||||
o_groups = cfg.get("o_groups", 16); o_rank = cfg.get("o_lora_rank", 1024)
|
||||
@@ -403,8 +507,16 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
scale = 1.0 / math.sqrt(hd); pfx = f"model.layers.{li}.self_attn"
|
||||
if positions.device != rope_cos.device: positions = positions.to(rope_cos.device)
|
||||
|
||||
def _pt(tag):
|
||||
"""Profile timing helper — records CUDA-sync'd timestamp."""
|
||||
if _profile_detail and _profile_times is not None:
|
||||
torch.cuda.synchronize()
|
||||
_profile_times.append((tag, li, time.perf_counter()))
|
||||
|
||||
_pt('q_a_start')
|
||||
# 1. Q: q_a (NVFP4 GEMM) → q_a_norm → q_b (NVFP4 GEMM) → q_b_norm
|
||||
q_a = prod_lin['q_a'](x_normed)
|
||||
_pt('q_a_end')
|
||||
if VERBOSE >= 2 and li < 3:
|
||||
# Compare q_a with PyTorch reference
|
||||
q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj')
|
||||
@@ -413,17 +525,24 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
|
||||
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
|
||||
if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32))
|
||||
_pt('q_b_start')
|
||||
q = prod_lin['q_b'](q_a); q = unweighted_rmsnorm(q).bfloat16()
|
||||
_pt('q_b_end')
|
||||
q_heads = q.reshape(T, n_h, hd); q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd)
|
||||
_pt('rope_q_end')
|
||||
|
||||
# 2. KV (NVFP4 GEMM, MQA, single KV head)
|
||||
_pt('kv_start')
|
||||
kv = prod_lin['kv'](x_normed)
|
||||
_pt('kv_end')
|
||||
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
||||
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
||||
kv_3d = kv.reshape(T, 1, hd); kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd)
|
||||
_pt('rope_kv_end')
|
||||
kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions)
|
||||
|
||||
# 3. Compressor → compressed KV
|
||||
_pt('compress_start')
|
||||
comp_kv, comp_pos, block_bias = None, None, None; comp_idx_kv = None
|
||||
if compressor is not None and compressor.ratio > 0:
|
||||
comp_kv, comp_pos, block_bias = compressor.forward(x_normed, positions)
|
||||
@@ -434,6 +553,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions)
|
||||
kv_cache.add_compressed(comp_kv, comp_pos, comp_idx_kv)
|
||||
_pt('compress_end')
|
||||
|
||||
# 4. Indexer top-k (CSA)
|
||||
topk_idx = None
|
||||
@@ -441,6 +561,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions)
|
||||
|
||||
# 5. Gather KV
|
||||
_pt('gather_start')
|
||||
swa_kv, swa_pos = kv_cache.get_swa()
|
||||
if kv_cache.comp_kv is not None and kv_cache.n_comp > 0:
|
||||
if ratio == 4 and topk_idx is not None:
|
||||
@@ -453,7 +574,9 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
if seq_len == 0: return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a
|
||||
|
||||
# 6. Production FMHA
|
||||
_pt('fmha_start')
|
||||
attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx)
|
||||
_pt('fmha_end')
|
||||
if VERBOSE >= 2 and li < 3:
|
||||
# Compare with PyTorch reference
|
||||
k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous()
|
||||
@@ -464,9 +587,12 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
cos_sim = torch.nn.functional.cosine_similarity(attn_out.flatten().float(), ref_attn.flatten().float(), dim=0).item()
|
||||
print(f" L{li} FMHA: |prod|={attn_out.abs().max().item():.6f} |ref|={ref_attn.abs().max().item():.6f} cos={cos_sim:.6f}", flush=True)
|
||||
# 7. Inverse RoPE
|
||||
_pt('inv_rope_start')
|
||||
attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)
|
||||
_pt('inv_rope_end')
|
||||
|
||||
# 8. Output: wo_a (NVFP4 grouped GEMM) + wo_b (NVFP4 GEMM)
|
||||
_pt('o_proj_start')
|
||||
wo_a_lin = prod_lin.get('o_a')
|
||||
if wo_a_lin is not None:
|
||||
# Nvfp4GroupedLinear: (T, n_h, hd) → (T, n_groups, o_rank) → flatten for o_b
|
||||
@@ -486,6 +612,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
else:
|
||||
log.warning(f"L{li}: No o_a_proj weight, zero attention output")
|
||||
F_attn = torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev)
|
||||
_pt('o_proj_end')
|
||||
if VERBOSE >= 2 and li < 3:
|
||||
print(f" L{li} F_attn: |F_attn|={F_attn.abs().max().item():.6f}", flush=True)
|
||||
return F_attn, q_a
|
||||
@@ -535,11 +662,12 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, token_id,
|
||||
compressor=None, indexer=None,
|
||||
moe_runner=None, se_runner=None, router=None,
|
||||
prod_lin=None, _profile_detail=False):
|
||||
prod_lin=None, _profile_detail=False, _profile_times=None):
|
||||
x_in, ctx_a = attn_mhc.pre_block(X_l); x_normed = rmsnorm(x_in, attn_norm_w)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_attn0 = time.perf_counter()
|
||||
F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
kv_cache, positions, compressor, indexer, prod_lin)
|
||||
kv_cache, positions, compressor, indexer, prod_lin,
|
||||
_profile_detail=_profile_detail, _profile_times=_profile_times)
|
||||
if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter()
|
||||
X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a)
|
||||
x_in_f, ctx_f = ffn_mhc.pre_block(X_mid); x_ffn = rmsnorm(x_in_f, ffn_norm_w)
|
||||
@@ -1018,6 +1146,7 @@ def main():
|
||||
cuda_layer_events[tag] = (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
|
||||
layer_event_accum = {tag: 0.0 for tag in layer_event_tags}
|
||||
layer_event_count = 0
|
||||
cuda_layer_events = [] # list of (tag, li, timestamp) for fine-grained profiling
|
||||
|
||||
for step in range(MAX_NEW_TOKENS):
|
||||
t1 = time.time()
|
||||
@@ -1038,7 +1167,8 @@ def main():
|
||||
compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li),
|
||||
_profile_detail=(profile and step == 1))
|
||||
_profile_detail=(profile and step == 1),
|
||||
_profile_times=cuda_layer_events if (profile and step == 1) else None)
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
t_layers = time.perf_counter()
|
||||
|
||||
@@ -1159,6 +1289,17 @@ def main():
|
||||
print(f" Embed + 61 layers: {prof_embed_layers:.3f}s total, {prof_embed_layers/n*1000:.1f}ms/token")
|
||||
print(f" hc_head + norm + lm_head: {prof_lm_head:.3f}s total, {prof_lm_head/n*1000:.1f}ms/token")
|
||||
print(f" Sampling: {prof_sample:.3f}s total, {prof_sample/n*1000:.1f}ms/token")
|
||||
|
||||
# Fine-grained attention profile (from step 1)
|
||||
if hasattr(cuda_layer_events, '__len__') and len(cuda_layer_events) >= 2:
|
||||
print(f"\n FINE-GRAINED ATTENTION PROFILE (step 1, CUDA-sync'd):")
|
||||
prev_t = None
|
||||
for tag, li, t in cuda_layer_events:
|
||||
if prev_t is not None:
|
||||
dt_ms = (t - prev_t) * 1000
|
||||
if li <= 2 or li >= 58: # Only print for first/last layers
|
||||
print(f" L{li} {tag}: {dt_ms:.2f}ms")
|
||||
prev_t = t
|
||||
|
||||
out = tokenizer.decode(all_tokens, skip_special_tokens=True)
|
||||
print(f"\n{'='*70}")
|
||||
|
||||
Reference in New Issue
Block a user