[V1][Core] Fix memory issue with logits & sampling (#14508)
Signed-off-by: Roger Wang <ywang@roblox.com> Co-authored-by: Varun Sundar Rabindranath <3337719+varun-sundar-rabindranath@users.noreply.github.com>
This commit is contained in:
@@ -1202,41 +1202,98 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self,
|
||||
num_tokens: int,
|
||||
) -> torch.Tensor:
|
||||
model = self.model
|
||||
if self.is_multimodal_model:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
inputs_embeds = None
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_tokens]
|
||||
else:
|
||||
positions = self.positions[:num_tokens]
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
else:
|
||||
if self.intermediate_tensors is None:
|
||||
self.intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors(
|
||||
batch_size=self.max_num_tokens,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device))
|
||||
intermediate_tensors = IntermediateTensors({
|
||||
k: v[:num_tokens]
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
# has num_tokens in total.
|
||||
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
|
||||
min_tokens_per_req = num_tokens // num_reqs
|
||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||
assert sum(num_scheduled_tokens_list) == num_tokens
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
|
||||
with set_forward_context(None, self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return hidden_states
|
||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||
num_scheduled_tokens):
|
||||
model = self.model
|
||||
if self.is_multimodal_model:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
inputs_embeds = None
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_tokens]
|
||||
else:
|
||||
positions = self.positions[:num_tokens]
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
else:
|
||||
if self.intermediate_tensors is None:
|
||||
self.intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors(
|
||||
batch_size=self.max_num_tokens,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device))
|
||||
intermediate_tensors = IntermediateTensors({
|
||||
k: v[:num_tokens]
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
with set_forward_context(None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
||||
return hidden_states[logit_indices]
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_sampler_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
num_reqs = logits.size(0)
|
||||
|
||||
dummy_tensors = lambda v: torch.full(
|
||||
(num_reqs, ), v, device=self.device)
|
||||
|
||||
dummy_metadata = SamplingMetadata(
|
||||
temperature=dummy_tensors(0.5),
|
||||
all_greedy=False,
|
||||
all_random=False,
|
||||
top_p=dummy_tensors(0.9),
|
||||
top_k=dummy_tensors(logits.size(1) - 1),
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
prompt_token_ids=None,
|
||||
frequency_penalties=dummy_tensors(0.1),
|
||||
presence_penalties=dummy_tensors(0.1),
|
||||
repetition_penalties=dummy_tensors(0.1),
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
min_tokens={},
|
||||
logit_bias=[None for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
)
|
||||
sampler_output = self.model.sample(logits=logits,
|
||||
sampling_metadata=dummy_metadata)
|
||||
|
||||
return sampler_output
|
||||
|
||||
def profile_run(self) -> None:
|
||||
# Profile with multimodal encoder & encoder cache.
|
||||
@@ -1332,60 +1389,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Cache the dummy encoder outputs.
|
||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||
|
||||
# For profile, have maximum num_reqs and that collectively have
|
||||
# maximum num_tokens.
|
||||
num_reqs = self.scheduler_config.max_num_seqs
|
||||
num_tokens = self.max_num_tokens
|
||||
min_tokens_per_req = num_tokens // num_reqs
|
||||
|
||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||
assert sum(num_scheduled_tokens_list) == num_tokens
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
||||
|
||||
with self.maybe_profile_with_lora(self.lora_config,
|
||||
num_scheduled_tokens):
|
||||
# Trigger compilation for general shape.
|
||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
||||
if get_pp_group().is_last_rank:
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
dummy_tensors = lambda v: torch.full(
|
||||
(num_reqs, ), v, device=self.device)
|
||||
dummy_metadata = SamplingMetadata(
|
||||
temperature=dummy_tensors(0.5),
|
||||
all_greedy=False,
|
||||
all_random=False,
|
||||
top_p=dummy_tensors(0.9),
|
||||
top_k=dummy_tensors(logits.size(1) - 1),
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
prompt_token_ids=torch.ones_like(logits,
|
||||
dtype=torch.int64),
|
||||
frequency_penalties=dummy_tensors(0.1),
|
||||
presence_penalties=dummy_tensors(0.1),
|
||||
repetition_penalties=dummy_tensors(0.1),
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
min_tokens={},
|
||||
logit_bias=[None for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
)
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits, sampling_metadata=dummy_metadata)
|
||||
else:
|
||||
logits = None
|
||||
sampler_output = None
|
||||
dummy_metadata = None
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, logits, sampler_output, dummy_metadata
|
||||
self.encoder_cache.clear()
|
||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
||||
if get_pp_group().is_last_rank:
|
||||
sampler_output = self._dummy_sampler_run(hidden_states)
|
||||
else:
|
||||
sampler_output = None
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sampler_output
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
def capture_model(self) -> None:
|
||||
|
||||
@@ -119,6 +119,8 @@ class Worker(WorkerBase):
|
||||
self.model_runner: GPUModelRunner = GPUModelRunner(
|
||||
self.vllm_config, self.device)
|
||||
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
@@ -211,6 +213,27 @@ class Worker(WorkerBase):
|
||||
self.model_runner._dummy_run(size)
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Warm up sampler and preallocate memory buffer for logits and other
|
||||
# sampling related tensors of max possible shape to avoid memory
|
||||
# fragmentation issue.
|
||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
||||
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
||||
try:
|
||||
max_num_reqs = min(self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
self.model_runner._dummy_sampler_run(
|
||||
hidden_states=self.model_runner._dummy_run(
|
||||
num_tokens=max_num_reqs))
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"CUDA out of memory occurred when warming up sampler. "
|
||||
"Please try lowering `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from None
|
||||
else:
|
||||
raise e
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@@ -83,8 +83,8 @@ class LoRAModelRunnerMixin:
|
||||
lora_requests)
|
||||
|
||||
@contextmanager
|
||||
def maybe_profile_with_lora(self, lora_config: LoRAConfig,
|
||||
num_scheduled_tokens: np.ndarray):
|
||||
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
|
||||
num_scheduled_tokens: np.ndarray):
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
@@ -145,4 +145,4 @@ class LoRAModelRunnerMixin:
|
||||
def list_loras(self) -> set[int]:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.list_adapters()
|
||||
return self.lora_manager.list_adapters()
|
||||
|
||||
Reference in New Issue
Block a user