Revert "[V1][Core] Fix memory issue with logits & sampling" (#13775)
This commit is contained in:
@@ -1179,43 +1179,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def _dummy_sampler_run(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
logits = self.model.compute_logits(hidden_states, None)
|
|
||||||
num_reqs = logits.size(0)
|
|
||||||
|
|
||||||
dummy_tensors = lambda v: torch.full(
|
|
||||||
(num_reqs, ), v, device=self.device)
|
|
||||||
|
|
||||||
dummy_metadata = SamplingMetadata(
|
|
||||||
temperature=dummy_tensors(0.5),
|
|
||||||
all_greedy=False,
|
|
||||||
all_random=False,
|
|
||||||
spec_token_ids=None,
|
|
||||||
top_p=dummy_tensors(0.9),
|
|
||||||
top_k=dummy_tensors(logits.size(1) - 1),
|
|
||||||
min_p=None,
|
|
||||||
generators={},
|
|
||||||
max_num_logprobs=None,
|
|
||||||
no_penalties=True,
|
|
||||||
prompt_token_ids=None,
|
|
||||||
frequency_penalties=dummy_tensors(0.1),
|
|
||||||
presence_penalties=dummy_tensors(0.1),
|
|
||||||
repetition_penalties=dummy_tensors(0.1),
|
|
||||||
output_token_ids=[[] for _ in range(num_reqs)],
|
|
||||||
min_tokens={},
|
|
||||||
logit_bias=[None for _ in range(num_reqs)],
|
|
||||||
allowed_token_ids_mask=None,
|
|
||||||
)
|
|
||||||
sampler_output = self.model.sample(logits=logits,
|
|
||||||
sampling_metadata=dummy_metadata)
|
|
||||||
|
|
||||||
return sampler_output
|
|
||||||
|
|
||||||
def profile_run(self) -> None:
|
def profile_run(self) -> None:
|
||||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||||
# it by reference, rather by specializing on the value `None`.
|
# it by reference, rather by specializing on the value `None`.
|
||||||
@@ -1343,11 +1306,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dummy_kv_caches)
|
dummy_kv_caches)
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
hidden_states = hidden_states[logit_indices]
|
hidden_states = hidden_states[logit_indices]
|
||||||
sampler_output = self._dummy_sampler_run(hidden_states)
|
logits = self.model.compute_logits(hidden_states, None)
|
||||||
|
dummy_tensors = lambda v: torch.full(
|
||||||
|
(num_reqs, ), v, device=self.device)
|
||||||
|
dummy_metadata = SamplingMetadata(
|
||||||
|
temperature=dummy_tensors(0.5),
|
||||||
|
all_greedy=False,
|
||||||
|
all_random=False,
|
||||||
|
spec_token_ids=None,
|
||||||
|
top_p=dummy_tensors(0.9),
|
||||||
|
top_k=dummy_tensors(logits.size(1) - 1),
|
||||||
|
min_p=None,
|
||||||
|
generators={},
|
||||||
|
max_num_logprobs=None,
|
||||||
|
no_penalties=True,
|
||||||
|
prompt_token_ids=torch.ones_like(logits,
|
||||||
|
dtype=torch.int64),
|
||||||
|
frequency_penalties=dummy_tensors(0.1),
|
||||||
|
presence_penalties=dummy_tensors(0.1),
|
||||||
|
repetition_penalties=dummy_tensors(0.1),
|
||||||
|
output_token_ids=[[] for _ in range(num_reqs)],
|
||||||
|
min_tokens={},
|
||||||
|
logit_bias=[None for _ in range(num_reqs)],
|
||||||
|
allowed_token_ids_mask=None,
|
||||||
|
)
|
||||||
|
sampler_output = self.model.sample(
|
||||||
|
logits=logits, sampling_metadata=dummy_metadata)
|
||||||
else:
|
else:
|
||||||
|
logits = None
|
||||||
sampler_output = None
|
sampler_output = None
|
||||||
|
dummy_metadata = None
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
del hidden_states, sampler_output
|
del hidden_states, logits, sampler_output, dummy_metadata
|
||||||
self.encoder_cache.clear()
|
self.encoder_cache.clear()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|||||||
@@ -211,16 +211,6 @@ class Worker(WorkerBase):
|
|||||||
self.model_runner._dummy_run(size)
|
self.model_runner._dummy_run(size)
|
||||||
if not self.model_config.enforce_eager:
|
if not self.model_config.enforce_eager:
|
||||||
self.model_runner.capture_model()
|
self.model_runner.capture_model()
|
||||||
|
|
||||||
# Warm up sampler and preallocate memory buffer for logits and other
|
|
||||||
# sampling related tensors of max possible shape to avoid memory
|
|
||||||
# fragmentation issue.
|
|
||||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
|
||||||
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
|
||||||
self.model_runner._dummy_sampler_run(
|
|
||||||
hidden_states=self.model_runner._dummy_run(
|
|
||||||
num_tokens=self.scheduler_config.max_num_seqs))
|
|
||||||
|
|
||||||
# Reset the seed to ensure that the random state is not affected by
|
# Reset the seed to ensure that the random state is not affected by
|
||||||
# the model initialization and profiling.
|
# the model initialization and profiling.
|
||||||
set_random_seed(self.model_config.seed)
|
set_random_seed(self.model_config.seed)
|
||||||
|
|||||||
Reference in New Issue
Block a user