[Model Runner V2] Fix warmup for pipeline parallel (#36280)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-06 16:58:51 -08:00
committed by GitHub
parent 6a18d8789b
commit b354686524
2 changed files with 16 additions and 8 deletions

View File

@@ -584,7 +584,7 @@ class Worker(WorkerBase):
if self.use_v2_model_runner:
# V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
warmup_kernels(self.model_runner)
warmup_kernels(self.model_runner, self.execute_model, self.sample_tokens)
elif get_pp_group().is_last_rank:
# V1: Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory