[V0] Correct CUDA Graph capture for encoder-decoder models (#22630)
This commit is contained in:
@@ -1164,8 +1164,18 @@ class ModelConfig:
|
|||||||
"non-quantized models.", self.quantization)
|
"non-quantized models.", self.quantization)
|
||||||
|
|
||||||
def _verify_cuda_graph(self) -> None:
|
def _verify_cuda_graph(self) -> None:
|
||||||
|
# The `max_seq_len_to_capture` was incorrectly
|
||||||
|
# based on the encoder's input length (448)
|
||||||
|
# but not the decoder's larger input length (1500).
|
||||||
|
# This change ensures the CUDA Graph captures the correct,
|
||||||
|
# larger sequence length, allowing it to work as intended.
|
||||||
|
effective_max_seq_len = self.max_model_len
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
effective_max_seq_len = max(
|
||||||
|
effective_max_seq_len,
|
||||||
|
getattr(self.hf_config, "max_source_positions", 0))
|
||||||
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
|
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
|
||||||
self.max_model_len)
|
effective_max_seq_len)
|
||||||
# CUDAGraph capture not supported for enc-dec models and mllama on ROCm
|
# CUDAGraph capture not supported for enc-dec models and mllama on ROCm
|
||||||
ROCM_UNSUPPORTED_MODELS = ['mllama']
|
ROCM_UNSUPPORTED_MODELS = ['mllama']
|
||||||
unsupported_rocm = (self.hf_config.model_type
|
unsupported_rocm = (self.hf_config.model_type
|
||||||
|
|||||||
Reference in New Issue
Block a user