[Bugfix] Ensure calculated KV scales are applied in attention. (#27232)
Signed-off-by: adabeyta <aabeyta@redhat.com>
(cherry picked from commit a5a790eea6)
Committed by: Kevin H. Luu
Parent: 30700b1cd7
Commit: 75ecaf48fe
@@ -471,7 +471,7 @@ steps:
   - vllm/
   - tests/compile
   commands:
-  - pytest -v -s compile/test_full_graph.py
+  - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
   # Limit to no custom ops to reduce running time
   # Wrap with quotes to escape yaml and avoid starting -k string with a -
   - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
@@ -951,10 +951,13 @@ steps:
   - vllm/model_executor/layers/activation.py
   - vllm/model_executor/layers/quantization/input_quant_fp8.py
   - tests/compile/test_fusions_e2e.py
+  - tests/compile/test_full_graph.py
   commands:
   - nvidia-smi
   # Run all e2e fusion tests
   - pytest -v -s tests/compile/test_fusions_e2e.py
+  # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
+  - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile

 - label: Blackwell GPT-OSS Eval
   timeout_in_minutes: 60
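The two CI pipeline hunks above split where the new test runs: the default compile job (L4/L40 runners) now deselects test_fp8_kv_scale_compile, and the Blackwell job picks it up explicitly because it needs FlashAttention. A minimal sketch of reproducing both selections locally with pytest (an assumed local workflow, not part of the commit):

import pytest

# Default job: run the compile full-graph suite but skip the FP8 KV-scale test.
pytest.main(["-v", "-s", "compile/test_full_graph.py", "-k", "not test_fp8_kv_scale_compile"])

# Blackwell job: run only the FP8 KV-scale test (requires FlashAttention support).
pytest.main(["-v", "-s", "tests/compile/test_full_graph.py::test_fp8_kv_scale_compile"])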
@@ -183,8 +183,14 @@ def test_custom_compile_config(
     "compilation_mode",
     [CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
 )
-def test_fp8_kv_scale_compile(compilation_mode: int):
-    model = "Qwen/Qwen2-0.5B"
+@pytest.mark.parametrize(
+    "model",
+    [
+        "Qwen/Qwen2-0.5B",  # Standard attention model
+        "deepseek-ai/DeepSeek-V2-Lite",  # MLA (Multi-head Latent Attention) model
+    ],
+)
+def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
     model_kwargs = {
         "quantization": "fp8",
         "kv_cache_dtype": "fp8_e4m3",
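The test in tests/compile/test_full_graph.py is now parametrized over a standard-attention model and an MLA model, so the MLA path touched by this fix gets coverage. A minimal sketch of what each parametrized case exercises (assumed harness; the flags below are presumed to be forwarded as engine arguments, and the real test builds the engine through its own helpers):

from vllm import LLM, SamplingParams

def smoke_fp8_kv_scales(model: str) -> None:
    # fp8 KV cache, with scales calculated from the first forward pass
    llm = LLM(
        model=model,
        quantization="fp8",
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,
    )
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    assert out[0].outputs[0].text is not None

for name in ["Qwen/Qwen2-0.5B", "deepseek-ai/DeepSeek-V2-Lite"]:
    smoke_fp8_kv_scales(name)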
@@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         k_pe: torch.Tensor,
         output_shape: torch.Size | None = None,
     ) -> torch.Tensor:
+        if self.calculate_kv_scales:
+            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
+
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             attn_metadata = forward_context.attn_metadata
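This hunk is the core of the fix for the compiled path: MLAAttention.forward now invokes the vLLM custom op once, before branching, so the scale calculation is kept as an opaque call under torch.compile instead of being traced away or skipped. Roughly how such an op stays opaque to the compiler, shown as a toy registration under an assumed "toylib" namespace (not vLLM's actual registration code):

import torch

@torch.library.custom_op("toylib::maybe_calc_kv_scales", mutates_args=())
def toy_maybe_calc_kv_scales(q: torch.Tensor, layer_name: str) -> None:
    # Side-effecting, data-dependent work stays outside the traced graph.
    pass

@toy_maybe_calc_kv_scales.register_fake
def _(q: torch.Tensor, layer_name: str) -> None:
    return None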
@@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                 attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]

-            # Mirror Attention.forward scale calculation path
-            if self.calculate_kv_scales and getattr(
-                attn_metadata, "enable_kv_scales_calculation", False
-            ):
-                self.calc_kv_scales(q, kv_c_normed, k_pe)
-
             if self.attn_backend.accept_output_buffer:
                 output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 self.impl.forward(
@@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                 )
                 return output
         else:
-            # We can still access forward context to check calculation flag
-            if self.calculate_kv_scales:
-                forward_context = get_forward_context()
-                attn_metadata = forward_context.attn_metadata
-                if isinstance(attn_metadata, dict):
-                    attn_metadata = attn_metadata[self.layer_name]
-                if getattr(attn_metadata, "enable_kv_scales_calculation", False):
-                    self.calc_kv_scales(q, kv_c_normed, k_pe)
             return torch.ops.vllm.unified_mla_attention(
                 q,
                 kv_c_normed,
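With the single call at the top of forward, the two older blocks above, one in the direct-call branch and one in the compiled branch, that re-checked enable_kv_scales_calculation on the attention metadata are no longer needed. For reference, the scales being calculated are per-tensor fp8 scaling factors; a hedged sketch of the usual formula (the layer's actual calc_kv_scales may differ in detail):

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # about 448.0

def per_tensor_fp8_scale(t: torch.Tensor) -> torch.Tensor:
    # Scale chosen so that t / scale fits the fp8 e4m3 dynamic range.
    return t.abs().amax().to(torch.float32) / FP8_E4M3_MAX

k_scale = per_tensor_fp8_scale(torch.randn(2, 8, 64))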
@@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-
-    if attn_metadata is None or not getattr(
-        attn_metadata, "enable_kv_scales_calculation", False
-    ):
-        return
-
-    self = forward_context.no_compile_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
+
+    # Only calculate if the layer's calculate_kv_scales flag is True
+    # This flag gets set to False after the first forward pass
+    if not self.calculate_kv_scales:
+        return
+
     self.calc_kv_scales(query, key, value)
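The op now keys off the owning layer's calculate_kv_scales flag rather than per-step attention metadata, in line with the commit's goal of making sure the calculated scales are actually applied: calc_kv_scales runs exactly once and the flag is then cleared. A toy stand-in (assumed names, not the real layer) of that one-shot behavior:

class ToyLayer:
    def __init__(self) -> None:
        self.calculate_kv_scales = True
        self.k_scale = 1.0

    def calc_kv_scales(self, query, key, value) -> None:
        self.k_scale = 0.02  # placeholder for the real max-abs based scale
        self.calculate_kv_scales = False  # only the first forward pass computes scales

def maybe_calc_scales_once(layer: ToyLayer, query=None, key=None, value=None) -> None:
    if not layer.calculate_kv_scales:
        return
    layer.calc_kv_scales(query, key, value)

layer = ToyLayer()
maybe_calc_scales_once(layer)  # first step: scales computed
maybe_calc_scales_once(layer)  # later steps: no-op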
@@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # This will be overridden in load_model()
         self.is_multimodal_pruning_enabled = False
         self.max_model_len = model_config.max_model_len
+
+        # Always set to false after the first forward pass
+        self.calculate_kv_scales = self.cache_config.calculate_kv_scales
         self.dcp_world_size = self.parallel_config.decode_context_parallel_size
         self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
@@ -2625,16 +2628,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )

         # Set cudagraph mode to none if calc_kv_scales is true.
-        if attn_metadata is not None:
-            metadata_list = (
-                attn_metadata.values()
-                if isinstance(attn_metadata, dict)
-                else [attn_metadata]
-            )
-            if any(
-                getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
-            ):
-                cudagraph_runtime_mode = CUDAGraphMode.NONE
+        # KV scales calculation involves dynamic operations that are incompatible
+        # with CUDA graph capture.
+        if self.calculate_kv_scales:
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
+            # Mark KV scales as calculated after the first forward pass
+            self.calculate_kv_scales = False

         # Run the model.
         # Use persistent buffers for CUDA graphs.
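On the runner side, the flag added in the constructor above replaces the per-step metadata inspection: the first step that computes KV scales is forced to run eagerly, since the data-dependent scale calculation cannot be captured in a CUDA graph, and the flag is then cleared so later steps can use captured graphs again. A toy paraphrase (assumed names, not the runner's real interface):

from enum import Enum

class CUDAGraphMode(Enum):
    NONE = 0
    FULL = 1

class ToyRunner:
    def __init__(self, calculate_kv_scales: bool) -> None:
        self.calculate_kv_scales = calculate_kv_scales

    def pick_mode(self, default: CUDAGraphMode) -> CUDAGraphMode:
        if self.calculate_kv_scales:
            self.calculate_kv_scales = False  # one-shot: only the first pass runs eagerly
            return CUDAGraphMode.NONE
        return default

runner = ToyRunner(calculate_kv_scales=True)
print(runner.pick_mode(CUDAGraphMode.FULL))  # first step -> CUDAGraphMode.NONE
print(runner.pick_mode(CUDAGraphMode.FULL))  # later steps -> CUDAGraphMode.FULL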