diff --git a/tests/compile/passes/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py
index 5c2d03213..94014ca01 100644
--- a/tests/compile/passes/test_fusion_attn.py
+++ b/tests/compile/passes/test_fusion_attn.py
@@ -53,6 +53,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
         kv_cache_dtype: torch.dtype,
         device: torch.device,
         vllm_config: VllmConfig,
+        block_size: int,
         **kwargs,
     ):
         super().__init__()
@@ -74,7 +75,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
         self.attn._k_scale = self.attn._k_scale.to(device)
         self.attn._v_scale = self.attn._v_scale.to(device)
 
-        self.block_size = 16
+        self.block_size = block_size
 
         # Initialize attn MetadataBuilder
         self.builder = self.attn.attn_backend.get_builder_cls()(
@@ -299,6 +300,9 @@ def test_attention_quant_pattern(
     torch.set_default_dtype(dtype)
     torch.manual_seed(42)
 
+    backend_cls = backend.get_class()
+    block_size = backend_cls.get_preferred_block_size(16)
+
     model_config = ModelConfig(
         model=model_name,
         max_model_len=2048,
@@ -342,6 +346,7 @@ def test_attention_quant_pattern(
         kv_cache_dtype=FP8_DTYPE,
         device=device,
         vllm_config=vllm_config_unfused,
+        block_size=block_size,
     )
     model_unfused = model_unfused.to(device)
     result_unfused_0 = model_unfused(q, k, v)  # noqa: F841 HACK: See #131044
@@ -370,6 +375,7 @@ def test_attention_quant_pattern(
         device=device,
         vllm_config=vllm_config,
         w=model_unfused.w,
+        block_size=block_size,
     )
     model_fused = model_fused.to(device)