[Kernel] LoRA - Enable CUDAGraphs for V1 (#14626)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
32ef4983cd
commit
0b1cfa6180
@@ -2287,9 +2287,14 @@ class LoRAConfig:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
         # no factors to consider.
         # LoRA is not compatible with `torch.compile` .
         factors: list[Any] = []
+        factors.append(self.max_lora_rank)
+        factors.append(self.max_loras)
+        factors.append(self.fully_sharded_loras)
+        factors.append(self.lora_dtype)
+        factors.append(self.lora_extra_vocab_size)
+        factors.append(self.long_lora_scaling_factors)
+        factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str

||||
@@ -3303,6 +3308,11 @@ class VllmConfig:
             vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+            # LoRA creates static buffers based on max_num_batched_tokens.
+            # The tensor sizes and strides get captured in the torch.compile
+            # graph explicitly.
+            vllm_factors.append(
+                str(self.scheduler_config.max_num_batched_tokens))
         else:
             vllm_factors.append("None")
         if self.speculative_config:
||||
@@ -3453,12 +3463,15 @@ class VllmConfig:
                 " Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

-        if self.lora_config is not None and self.compilation_config.level !=\
-                CompilationLevel.NO_COMPILATION:
-            logger.warning("LoRA is not supported with `torch.compile` yet. "
-                           "Disabling `torch.compile`.")
+        if ((not envs.VLLM_USE_V1) and self.lora_config is not None
+                and self.compilation_config.level
+                != CompilationLevel.NO_COMPILATION):
+            logger.warning(
+                "LoRA for V0 is not supported with `torch.compile` yet. "
+                "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

         if self.model_config and self.model_config.use_mla and \
                 not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(
||||
Reference in New Issue
Block a user