[V1][Core] Fix memory issue with logits & sampling (#13721)
This commit is contained in:
@@ -211,6 +211,16 @@ class Worker(WorkerBase):
|
||||
self.model_runner._dummy_run(size)
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Warm up sampler and preallocate memory buffer for logits and other
|
||||
# sampling related tensors of max possible shape to avoid memory
|
||||
# fragmentation issue.
|
||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
||||
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
||||
self.model_runner._dummy_sampler_run(
|
||||
hidden_states=self.model_runner._dummy_run(
|
||||
num_tokens=self.scheduler_config.max_num_seqs))
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
Reference in New Issue
Block a user