[ROCm][Test] Fix ROCM_AITER_UNIFIED_ATTN attn+quant fusion test (#37640)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Author: vllmellm
Date: 2026-03-25 13:06:15 +08:00 (committed by GitHub)
Parent: a32783bb35
Commit: 42e9547976

@@ -53,6 +53,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
         kv_cache_dtype: torch.dtype,
         device: torch.device,
         vllm_config: VllmConfig,
+        block_size: int,
         **kwargs,
     ):
         super().__init__()
@@ -74,7 +75,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
         self.attn._k_scale = self.attn._k_scale.to(device)
         self.attn._v_scale = self.attn._v_scale.to(device)
 
-        self.block_size = 16
+        self.block_size = block_size
 
         # Initialize attn MetadataBuilder
         self.builder = self.attn.attn_backend.get_builder_cls()(
@@ -299,6 +300,9 @@ def test_attention_quant_pattern(
     torch.set_default_dtype(dtype)
     torch.manual_seed(42)
 
+    backend_cls = backend.get_class()
+    block_size = backend_cls.get_preferred_block_size(16)
+
     model_config = ModelConfig(
         model=model_name,
         max_model_len=2048,
@@ -342,6 +346,7 @@ def test_attention_quant_pattern(
         kv_cache_dtype=FP8_DTYPE,
         device=device,
         vllm_config=vllm_config_unfused,
+        block_size=block_size,
     )
     model_unfused = model_unfused.to(device)
     result_unfused_0 = model_unfused(q, k, v)  # noqa: F841 HACK: See #131044
@@ -370,6 +375,7 @@ def test_attention_quant_pattern(
         device=device,
         vllm_config=vllm_config,
         w=model_unfused.w,
+        block_size=block_size,
     )
     model_fused = model_fused.to(device)
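
The substance of the fix is that the test no longer assumes a KV-cache block size of 16: it asks the selected attention backend for its preferred block size and threads that value through both the unfused and fused AttentionQuantPatternModel instances. Below is a minimal, self-contained sketch of how such a get_preferred_block_size hook could behave; the _ExampleBackend class, its supported sizes, and its fallback-to-smallest-supported-size logic are illustrative assumptions, not the actual vLLM implementation (only the method name and the get_preferred_block_size(16) call appear in this diff).

    class _ExampleBackend:
        # Hypothetical backend: the block sizes it can run with are made up
        # for illustration; real backends define their own constraints.
        SUPPORTED_BLOCK_SIZES = (64, 128)

        @classmethod
        def get_preferred_block_size(cls, default: int) -> int:
            # Keep the caller's default when the backend supports it,
            # otherwise fall back to the smallest supported size
            # (fallback behavior is an assumption, not taken from this diff).
            if default in cls.SUPPORTED_BLOCK_SIZES:
                return default
            return min(cls.SUPPORTED_BLOCK_SIZES)


    # The test would then pass the result wherever it previously hard-coded 16,
    # e.g. AttentionQuantPatternModel(..., block_size=block_size).
    block_size = _ExampleBackend.get_preferred_block_size(16)
    print(block_size)  # -> 64 for this illustrative backend

The design point is that a backend such as ROCM_AITER_UNIFIED_ATTN can dictate a block size it actually supports, while a backend with no constraint could simply echo the caller's default back.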