[Disagg][Perf] Use CUDA event sync instead of blocking tolist to avoid unintentional copy ops blocking across different CUDA streams, improving disagg TTIT/TTFT (#22760)
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> Signed-off-by: Zijing Liu <liuzijing2014@users.noreply.github.com>
This commit is contained in:
@@ -316,6 +316,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Cached outputs.
|
# Cached outputs.
|
||||||
self._draft_token_ids: Optional[Union[list[list[int]],
|
self._draft_token_ids: Optional[Union[list[list[int]],
|
||||||
torch.Tensor]] = None
|
torch.Tensor]] = None
|
||||||
|
self.transfer_event = torch.cuda.Event()
|
||||||
|
self.sampled_token_ids_pinned_cpu = torch.empty(
|
||||||
|
(self.max_model_len, 1),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=True)
|
||||||
|
|
||||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||||
return CpuGpuBuffer(*args,
|
return CpuGpuBuffer(*args,
|
||||||
@@ -1691,7 +1697,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
max_gen_len = sampled_token_ids.shape[-1]
|
max_gen_len = sampled_token_ids.shape[-1]
|
||||||
if max_gen_len == 1:
|
if max_gen_len == 1:
|
||||||
# No spec decode tokens.
|
# No spec decode tokens.
|
||||||
valid_sampled_token_ids = sampled_token_ids.tolist()
|
valid_sampled_token_ids = self._to_list(sampled_token_ids)
|
||||||
else:
|
else:
|
||||||
# Includes spec decode tokens.
|
# Includes spec decode tokens.
|
||||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
||||||
@@ -3233,3 +3239,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
mamba_type=mamba_module.mamba_type)
|
mamba_type=mamba_module.mamba_type)
|
||||||
|
|
||||||
return kv_cache_spec
|
return kv_cache_spec
|
||||||
|
|
||||||
|
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
||||||
|
# This is a short term mitigation for issue mentioned in
|
||||||
|
# https://github.com/vllm-project/vllm/issues/22754.
|
||||||
|
# `tolist` would trigger a cuda wise stream sync, which
|
||||||
|
# would block other copy ops from other cuda streams.
|
||||||
|
# A cuda event sync would avoid such a situation. Since
|
||||||
|
# this is in the critical path of every single model
|
||||||
|
# forward loop, this has caused perf issue for a disagg
|
||||||
|
# setup.
|
||||||
|
pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]]
|
||||||
|
pinned.copy_(sampled_token_ids, non_blocking=True)
|
||||||
|
self.transfer_event.record()
|
||||||
|
self.transfer_event.synchronize()
|
||||||
|
return pinned.tolist()
|
||||||
|
|||||||
Reference in New Issue
Block a user