[ROCm][Test] Fix ROCM_AITER_UNIFIED_ATTN attn+quant fusion test (#37640)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -53,6 +53,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
         kv_cache_dtype: torch.dtype,
         device: torch.device,
         vllm_config: VllmConfig,
+        block_size: int,
         **kwargs,
     ):
         super().__init__()
@@ -74,7 +75,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
         self.attn._k_scale = self.attn._k_scale.to(device)
         self.attn._v_scale = self.attn._v_scale.to(device)

-        self.block_size = 16
+        self.block_size = block_size

         # Initialize attn MetadataBuilder
         self.builder = self.attn.attn_backend.get_builder_cls()(
@@ -299,6 +300,9 @@ def test_attention_quant_pattern(
     torch.set_default_dtype(dtype)
     torch.manual_seed(42)

+    backend_cls = backend.get_class()
+    block_size = backend_cls.get_preferred_block_size(16)
+
     model_config = ModelConfig(
         model=model_name,
         max_model_len=2048,
@@ -342,6 +346,7 @@ def test_attention_quant_pattern(
         kv_cache_dtype=FP8_DTYPE,
         device=device,
         vllm_config=vllm_config_unfused,
+        block_size=block_size,
     )
     model_unfused = model_unfused.to(device)
     result_unfused_0 = model_unfused(q, k, v)  # noqa: F841 HACK: See #131044
@@ -370,6 +375,7 @@ def test_attention_quant_pattern(
         device=device,
         vllm_config=vllm_config,
         w=model_unfused.w,
+        block_size=block_size,
     )
     model_fused = model_fused.to(device)

||||
Reference in New Issue
Block a user