diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 58fd43569..0e5b21ddf 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -78,17 +78,13 @@ HF_MOUNT="/root/.cache/huggingface" commands=$@ echo "Commands:$commands" -if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then - commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"} -fi +commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"pytest -v -s basic_correctness/test_basic_correctness.py"} if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} fi -if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then - commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} -fi +commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"pytest -v -s compile/test_basic_correctness.py"} if [[ $commands == *"pytest -v -s lora"* ]]; then commands=${commands//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"} diff --git a/tests/kernels/test_triton_flash_attention.py b/tests/kernels/test_triton_flash_attention.py deleted file mode 100644 index 4b0bbb992..000000000 --- a/tests/kernels/test_triton_flash_attention.py +++ /dev/null @@ -1,516 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for the triton_flash_attention kernel - -Run `pytest tests/kernels/test_triton_flash_attention.py`. -""" - -import pytest -import torch - -from vllm.attention.ops.triton_flash_attention import ( - SUPPORTED_LAYOUTS, - MetaData, - compute_alibi_tensor, - scale_fp8, - triton_attention_rocm, -) -from vllm.platforms import current_platform - - -class ReferenceAttention: - def __init__( - self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata - ): - self.Z = Z - self.HQ = HQ - self.HK = HK - self.N_CTX_Q = N_CTX_Q - self.N_CTX_K = N_CTX_K - self.D_HEAD = D_HEAD - self.use_alibi = use_alibi - self.dtype = dtype - self.input_metadata = input_metadata - - def fwd(self, q, k, v): - scores = ( - torch.einsum("bhqd,bhkd->bhqk", q, k).float() * self.input_metadata.sm_scale - ) - if self.input_metadata.causal: - mask = torch.tril( - torch.ones(self.N_CTX_Q, self.N_CTX_K, device="cuda"), - diagonal=self.N_CTX_K - self.N_CTX_Q, - ) - scores[:, :, mask == 0] = float("-inf") - - if self.input_metadata.bias is not None: - scores += self.input_metadata.bias - - if self.use_alibi: - scores += compute_alibi_tensor( - self.input_metadata.alibi_slopes, self.N_CTX_Q, self.N_CTX_K - ) - - p = torch.softmax(scores, dim=-1) - if self.input_metadata.causal: - # If N_CTX_Q > N_CTX_K, there's at least one row of all -infs going - # into softmax. This creates a row of NaNs as -inf - -inf == NaN. - # So we fix this by converting the NaNs to 0s, which is what they - # should be out of the softmax. - nan_mask = torch.isnan(p) - p[nan_mask == 1] = 0 - ref_out = torch.einsum("bhqk,bhkd->bhqd", p.to(self.dtype), v) - # compare - if self.input_metadata.layout == "bshd": - ref_out = ref_out.transpose(1, 2).clone() - return ref_out - - def fwd_fp8(self, q_quantized, k_quantized, v_quantized): - q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to( - self.dtype - ) - k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to( - self.dtype - ) - v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to( - self.dtype - ) - result = self.fwd(q, k, v) - if self.input_metadata.o_scale is not None: - result, _ = scale_fp8(result, self.input_metadata.o_scale) - return result - - def fwd_fp8_kv(self, q, k_quantized, v_quantized): - k_descale, v_descale = ( - self.input_metadata.k_descale, - self.input_metadata.v_descale, - ) - k_dequantized = ( - k_quantized.to(torch.float32) * k_descale.to(torch.float32) - ).to(self.dtype) - v_dequantized = ( - v_quantized.to(torch.float32) * v_descale.to(torch.float32) - ).to(self.dtype) - return self.fwd(q, k_dequantized, v_dequantized) - - def varlen_fwd(self, q, k, v, is_mqa=False): - ref_out = torch.empty_like(q) - if is_mqa: - # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so - # the size aligns with Q. - k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand( - -1, -1, self.HQ // self.HK, -1 - ) - v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand( - -1, -1, self.HQ // self.HK, -1 - ) - else: - k_ref = k - v_ref = v - - for i in range(0, self.input_metadata.num_contexts): - start_q, start_k = ( - self.input_metadata.cu_seqlens_q[i], - self.input_metadata.cu_seqlens_k[i], - ) - end_q, end_k = ( - self.input_metadata.cu_seqlens_q[i + 1], - self.input_metadata.cu_seqlens_k[i + 1], - ) - k_curr = k_ref[start_k:end_k] - v_curr = v_ref[start_k:end_k] - if is_mqa: - k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) - v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) - scores = torch.einsum("qhd,khd->qhk", q[start_q:end_q], k_curr).float() - p = torch.softmax(scores * self.input_metadata.sm_scale, dim=-1).half() - ref_out[start_q:end_q] = torch.einsum("qhk,khd->qhd", p, v_curr) - return ref_out - - -def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False): - q_descale = None - if not fp8_kv: - q, q_descale = scale_fp8(q) - k, k_descale = scale_fp8(k) - v, v_descale = scale_fp8(v) - - # In real world use case, the p scale would be a parameter trained by the - # model. - p_scale = None - - o_scale = torch.rand(1, device="cuda", requires_grad=False) if use_o_scale else None - - return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale - - -def input_helper( - Z, - HQ, - HK, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - layout=None, - use_alibi=None, - causal=None, - is_fp8=False, - fp8_kv=False, - use_o_scale=False, - use_bias=False, -): - assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout." - - current_platform.seed_everything(0) - - # Initialize q, k, v - if layout == "bhsd": - q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) - k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) - elif layout == "bshd": - q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) - k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) - - if use_alibi: - # for n heads the set of slopes is the geometric sequence that starts - # 2^(-8/n) - alibi_slopes = torch.tensor( - [2 ** (-8 / HQ * i) for i in range(1, HQ + 1)], - dtype=torch.float32, - device="cuda", - ).repeat(Z, 1) - else: - alibi_slopes = None - - if use_bias: - bias = torch.randn( - (1, HQ, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda", requires_grad=False - ) - else: - bias = None - - q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) - k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) - v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) - - if is_fp8: - (q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale) = quantize_input( - q, k, v, use_o_scale=use_o_scale, fp8_kv=fp8_kv - ) - else: - q_descale = k_descale = v_descale = p_scale = o_scale = None - - input_metadata = MetaData( - sm_scale=D_HEAD**-0.5, - max_seqlens_q=N_CTX_Q, - max_seqlens_k=N_CTX_K, - layout=layout, - alibi_slopes=alibi_slopes, - alibi_batch=Z, - alibi_nheads=HQ, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - p_scale=p_scale, - o_scale=o_scale, - bias=bias, - seqlen_q=N_CTX_Q, - seqlen_k=N_CTX_K, - ) - return q, k, v, input_metadata - - -def varlen_input_helper( - Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False -): - current_platform.seed_everything(0) - - # Random sequence lengths. Using N_CTX as kind of max of sum of individual - # seqs - if not equal_seqlens: - max_seqlens_q = N_CTX_Q // Z - max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) - else: - seqlens_q = torch.full((Z,), N_CTX_Q // Z) - seqlens_k = torch.full((Z,), N_CTX_K // Z) - - # Calculate cumulative sequence lengths - cu_seqlens_q = torch.cat( - [ - torch.tensor([0], dtype=torch.int32), - seqlens_q.cumsum(dim=0, dtype=torch.int32), - ] - ) - cu_seqlens_k = torch.cat( - [ - torch.tensor([0], dtype=torch.int32), - seqlens_k.cumsum(dim=0, dtype=torch.int32), - ] - ) - cu_seqlens_q = cu_seqlens_q.to(device="cuda") - cu_seqlens_k = cu_seqlens_k.to(device="cuda") - - # Initialize q, k, v with variable lengths - total_q = cu_seqlens_q[-1].item() - total_k = cu_seqlens_k[-1].item() - q = ( - torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - k = ( - torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - v = ( - torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) - return q, k, v, input_metadata - - -@pytest.mark.parametrize( - "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (1, 48, 12, 1, 1, 64), - (4, 4, 4, 128, 128, 65), - (16, 48, 48, 1, 1, 128), - (64, 48, 24, 3, 3, 128), - (4, 4, 4, 113, 123, 1), - ], -) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("use_alibi", [True, False]) -@pytest.mark.parametrize("layout", ["bshd"]) -def test_op_fwd( - Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16 -): - current_platform.seed_everything(0) - q, k, v, input_metadata = input_helper( - Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, use_alibi, causal - ) - - o = torch.empty_like(q) - - # triton implementation - tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) - - # Transpose here if layout is bshd so we have same reference code for all - # layouts - if layout == "bshd": - q = q.transpose(1, 2).clone() - k = k.transpose(1, 2).clone() - v = v.transpose(1, 2).clone() - # Replicate K and V if using MQA/GQA - if HQ != HK: - k = ( - k.view(k.shape[0], k.shape[1], -1, k.shape[2], k.shape[3]) - .expand(-1, -1, HQ // HK, -1, -1) - .reshape(k.shape[0], -1, k.shape[2], k.shape[3]) - ) - v = ( - v.view(v.shape[0], v.shape[1], -1, v.shape[2], v.shape[3]) - .expand(-1, -1, HQ // HK, -1, -1) - .reshape(v.shape[0], -1, v.shape[2], v.shape[3]) - ) - - ref_impl = ReferenceAttention( - Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata - ) - ref_out = ref_impl.fwd(q, k, v) - - torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - - -@pytest.mark.parametrize( - "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), - ], -) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("layout", ["bhsd"]) -@pytest.mark.parametrize("use_o_scale", [True, False]) -@pytest.mark.skipif( - torch.cuda.get_device_capability() < (9, 0), - reason="Triton FP8 requires CUDA 9.0 or higher", -) -def test_op_fwd_fp8( - Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, use_o_scale, dtype=torch.float32 -): - current_platform.seed_everything(0) - - # Disable grad to save memory it won't run into OOM on CI machine. - # q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, - # dtype, layout) - - q_quantized, k_quantized, v_quantized, input_metadata = input_helper( - Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - causal=causal, - layout=layout, - is_fp8=True, - use_o_scale=use_o_scale, - ) - - o = torch.empty_like(q_quantized) if use_o_scale else None - - tri_out, _ = triton_attention_rocm( - q_quantized, k_quantized, v_quantized, o, input_metadata - ) - - ref_impl = ReferenceAttention( - Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata - ) - ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized) - - # compare - torch.testing.assert_close( - ref_out.to(torch.float32), tri_out.to(torch.float32), atol=7e-2, rtol=2e-1 - ) - - -@pytest.mark.parametrize( - "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), - (4, 4, 113, 123, 1), - ], -) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("layout", ["bhsd"]) -def test_op_fwd_fp8_kv( - Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float32 -): - current_platform.seed_everything(0) - - q, k_quantized, v_quantized, input_metadata = input_helper( - Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - causal=causal, - layout=layout, - is_fp8=True, - fp8_kv=True, - ) - - o = torch.empty_like(q) - - tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, input_metadata) - - ref_impl = ReferenceAttention( - Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata - ) - ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized) - - torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1) - - -@pytest.mark.parametrize( - "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), - ], -) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("use_bias", [True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): - current_platform.seed_everything(0) - q, k, v, input_metadata = input_helper( - Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - layout="bhsd", - causal=causal, - use_bias=use_bias, - ) - o = torch.empty_like(q) - - # triton implementation - tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) - - ref_impl = ReferenceAttention( - Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata - ) - ref_out = ref_impl.fwd(q, k, v) - - # compare - torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - - -# NOTE: Uses thd layout, so also tests thd. -@pytest.mark.parametrize( - "Z, H, N_CTX, D_HEAD", - [(1, 48, 256, 64), (4, 48, 512, 64), (16, 48, 512, 64), (64, 48, 128, 128)], -) -@pytest.mark.parametrize("causal", [True, False]) -def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) - - tri_out = torch.empty_like(q) - triton_attention_rocm(q, k, v, tri_out, input_metadata) - - ref_impl = ReferenceAttention( - Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata - ) - ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False) - - torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - - -# NOTE: Uses thd layout, so also tests thd. -@pytest.mark.parametrize( - "Z, HQ, HK, N_CTX, D_HEAD", - [ - (2, 48, 24, 128, 64), - (4, 48, 12, 256, 64), - (4, 48, 4, 512, 64), - (4, 64, 16, 128, 128), - ], -) -@pytest.mark.parametrize("causal", [False]) -def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): - q, k, v, input_metadata = varlen_input_helper( - Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype - ) - - tri_out = torch.empty_like(q) - triton_attention_rocm(q, k, v, tri_out, input_metadata) - - ref_impl = ReferenceAttention( - Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata - ) - ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True) - - torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 471826f21..2723bb21d 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -27,13 +27,7 @@ def test_models( example_prompts, model: str, dtype: str, - monkeypatch, ) -> None: - if current_platform.is_rocm(): - # ROCm Triton FA does not currently support sliding window attention - # switch to use ROCm CK FA backend - monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index d1b1c8d3d..93b984331 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -4,7 +4,6 @@ import pytest from vllm.config import PoolerConfig -from vllm.platforms import current_platform from ...utils import check_embeddings_close @@ -51,13 +50,7 @@ def test_models( vllm_runner, example_prompts, model, - monkeypatch, ) -> None: - if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm(): - # ROCm Triton FA does not currently support sliding window attention - # switch to use ROCm CK FA backend - monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - vllm_extra_kwargs = {} if model == "ssmits/Qwen2-7B-Instruct-embed-base": vllm_extra_kwargs["pooler_config"] = PoolerConfig( diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py index 91be6cd09..248245264 100644 --- a/tests/models/language/pooling/test_mm_classifier_conversion.py +++ b/tests/models/language/pooling/test_mm_classifier_conversion.py @@ -2,18 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.config.pooler import PoolerConfig -from vllm.platforms import current_platform def test_idefics_multimodal( vllm_runner, - monkeypatch, ) -> None: - if current_platform.is_rocm(): - # ROCm Triton FA does not currently support sliding window attention - # switch to use ROCm CK FA backend - monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - prompts = [ "Hello, my name is", "The president of the United States is", @@ -59,13 +52,7 @@ def update_config(config): def test_gemma_multimodal( vllm_runner, - monkeypatch, ) -> None: - if current_platform.is_rocm(): - # ROCm Triton FA does not currently support sliding window attention - # switch to use ROCm CK FA backend - monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - messages = [ { "role": "system", diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index 46504d025..c42186c7d 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -76,7 +76,6 @@ def test_prm_models( math_step_prompts, model: str, dtype: str, - monkeypatch, ) -> None: check_transformers_version( "Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53.2" @@ -85,11 +84,6 @@ def test_prm_models( if current_platform.is_cpu(): pytest.skip("CPU only supports V1") - if current_platform.is_rocm(): - # ROCm Triton FA does not currently support sliding window attention - # switch to use ROCm CK FA backend - monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.reward(math_step_prompts) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 4c79ac318..de74acf3b 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -5,7 +5,6 @@ image, embedding, and video support for different VLMs in vLLM. """ import math -import os from collections import defaultdict from pathlib import PosixPath @@ -38,13 +37,6 @@ from .vlm_utils.types import ( VLMTestType, ) -# This hack is needed for phi3v & paligemma models -# ROCm Triton FA can run into shared memory issues with these models, -# use other backends in the meantime -# FIXME (mattwong, gshtrasb, hongxiayan) -if current_platform.is_rocm(): - os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" - COMMON_BROADCAST_SETTINGS = { "test_type": VLMTestType.IMAGE, "dtype": "half", diff --git a/tests/models/multimodal/generation/test_phi4_multimodal.py b/tests/models/multimodal/generation/test_phi4_multimodal.py index cbc7dfca0..624562217 100644 --- a/tests/models/multimodal/generation/test_phi4_multimodal.py +++ b/tests/models/multimodal/generation/test_phi4_multimodal.py @@ -11,7 +11,6 @@ from huggingface_hub import snapshot_download from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest from vllm.multimodal.image import rescale_image_size -from vllm.platforms import current_platform from ....conftest import ( IMAGE_ASSETS, @@ -46,12 +45,6 @@ models = [model_path] target_dtype = "half" -# ROCm Triton FA can run into shared memory issues with these models, -# use other backends in the meantime -# FIXME (mattwong, gshtrasb, hongxiayan) -if current_platform.is_rocm(): - os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" - def run_test( hf_runner: type[HfRunner], diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 5619cecc0..665bed512 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -14,7 +14,6 @@ from vllm.assets.image import ImageAsset from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.image import convert_image_mode, rescale_image_size -from vllm.platforms import current_platform from ....conftest import ( IMAGE_ASSETS, @@ -68,12 +67,6 @@ def vllm_to_hf_output( target_dtype = "half" -# ROCm Triton FA can run into shared memory issues with these models, -# use other backends in the meantime -# FIXME (mattwong, gshtrasb, hongxiayan) -if current_platform.is_rocm(): - os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" - def run_test( hf_runner: type[HfRunner], diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 0af27aff9..334f9a65e 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -8,7 +8,6 @@ See also `tests/kernels/moe/test_ocp_mx_moe.py`. """ import importlib.metadata -import os from dataclasses import dataclass from importlib.util import find_spec @@ -246,8 +245,6 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig): task = "gsm8k" rtol = 0.03 - os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" - results = lm_eval.simple_evaluate( model="vllm", model_args=config.get_model_args(tp_size=8, model_max_len=38768), @@ -263,8 +260,6 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig): and measured_value + rtol > EXPECTED_VALUE ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" - del os.environ["VLLM_USE_TRITON_FLASH_ATTN"] - @pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py deleted file mode 100644 index c0ab35d07..000000000 --- a/vllm/attention/ops/triton_flash_attention.py +++ /dev/null @@ -1,932 +0,0 @@ -#!/usr/bin/env python -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Fused Attention -=============== - -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team - -Features supported: - -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. - -Not currently supported: - -1) Non power of two head dims - -""" - -import torch - -from vllm.platforms import current_platform -from vllm.triton_utils import tl, triton - -# Avoid misleading ROCm warning. -if current_platform.is_rocm(): - from vllm.platforms.rocm import on_gfx1x -else: - on_gfx1x = lambda *args, **kwargs: False - -torch_dtype: tl.constexpr = torch.float16 - - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - - -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) - - -@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] - - -@triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets( - philox_seed, philox_offset, dropout_p, m, n, stride - ).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) - - -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) - rng_keep = rng_output > dropout_p - return rng_keep - - -@triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) - else: - tensor = tl.load(block_ptr) - return tensor - - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - actual_seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, - USE_FP8: tl.constexpr, - qk_scale, - p_descale, -): - # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) - if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. - # check if this masking works for that case. - if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk = tl.where(causal_mask, qk, float("-inf")) - # -- compute qk ---- - qk += tl.dot(q, k) - if USE_FP8: - qk *= qk_scale - if bias_ptr is not None: - bias = load_fn( - bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero" - ) - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - philox_offset = ( - batch_philox_offset - + start_m * BLOCK_M * actual_seqlen_k - + start_n - - BLOCK_N - ) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) - if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), - ) - p = tl.where(keep, p, 0.0) - elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - - if USE_FP8: - p *= p_descale - - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance( - encoded_softmax_block_ptr, (0, BLOCK_N) - ) - return acc, l_i, m_i - - -def get_cdna_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": True}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=8, - ), - # TODO: This config fails with head_size not pow2 with data mismatches. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, - # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: - # triton.Config( - # { - # "BLOCK_M": 16, - # "BLOCK_N": 16, - # "waves_per_eu": 1, - # "PRE_LOAD_V": False, - # }, - # num_stages=1, - # num_warps=4, - # ), - ], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"] - - -def get_rdna_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: - # triton.Config( - # { - # 'BLOCK_M': 16, - # 'BLOCK_N': 16, - # 'waves_per_eu': 4, - # 'PRE_LOAD_V': False - # }, - # num_stages=1, - # num_warps=2), - # triton.Config( - # { - # 'BLOCK_M': 16, - # 'BLOCK_N': 16, - # 'waves_per_eu': 2, - # 'PRE_LOAD_V': False - # }, - # num_stages=1, - # num_warps=2), - # # Fall-back config. - # triton.Config( - # { - # 'BLOCK_M': 16, - # 'BLOCK_N': 16, - # 'waves_per_eu': 1, - # 'PRE_LOAD_V': False - # }, - # num_stages=1, - # num_warps=2), - ], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"] - - -def get_autotune_configs(): - if on_gfx1x(): - return get_rdna_autotune_configs() - else: - return get_cdna_autotune_configs() - - -autotune_configs, autotune_keys = get_autotune_configs() - -float8_info = torch.finfo(current_platform.fp8_dtype()) - - -@triton.autotune( - configs=autotune_configs, - key=autotune_keys, -) -@triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - q_scale, - k_scale, - v_scale, - p_scale, - p_descale, - o_descale, - L, - Out, - stride_qz: tl.int64, - stride_qh: tl.int64, - stride_qm: tl.int64, - stride_qk: tl.int64, - stride_kz: tl.int64, - stride_kh: tl.int64, - stride_kn: tl.int64, - stride_kk: tl.int64, - stride_vz: tl.int64, - stride_vh: tl.int64, - stride_vk: tl.int64, - stride_vn: tl.int64, - stride_oz: tl.int64, - stride_oh: tl.int64, - stride_om: tl.int64, - stride_on: tl.int64, - stride_bz: tl.int64, - stride_bh: tl.int64, - stride_bm: tl.int64, - stride_bn: tl.int64, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - USE_FP8: tl.constexpr, - USE_FP8_OUT: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - FP8_MIN: tl.constexpr = float8_info.min, - FP8_MAX: tl.constexpr = float8_info.max, -): - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - if VARLEN: - cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = MAX_SEQLENS_Q - seqlen_k = MAX_SEQLENS_K - - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N - ) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = ( - off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - ) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? - return - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL - - # Compute pointers for all the tensors used in this kernel. - q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None - if ENABLE_DROPOUT: - batch_philox_offset = ( - philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k - ) - else: - batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - encoded_softmax_block_ptr = 0 - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do not - # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") - if not USE_FP8: - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - acc_scale = 1.0 - else: - qk_scale *= q_scale * k_scale - acc_scale = p_scale * v_scale - - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - USE_FP8, - qk_scale, - p_descale, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance( - encoded_softmax_block_ptr, (0, n_full_blocks) - ) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - USE_FP8, - qk_scale, - p_descale, - ) - # epilogue - - if USE_FP8: - acc *= acc_scale - acc = acc / l_i[:, None] - if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - if USE_FP8_OUT: - acc *= o_descale - acc = tl.clamp(acc, FP8_MIN, FP8_MAX) - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full( - (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32 - ) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = tl.zeros((1,), tl.float32) - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) - - -def check_args( - q, - k, - v, - o, - varlen=True, - max_seqlens=None, - cu_seqlens_q=None, - cu_seqlens_k=None, -): - assert q.dim() == k.dim() and q.dim() == v.dim() - if varlen: - assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - assert cu_seqlens_k is not None - assert len(cu_seqlens_q) == len(cu_seqlens_k) - else: - assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - assert max_seqlens > 0 - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - assert head_size <= 256 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 - - -class _attention(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - causal=False, - sm_scale=1.0, - bias=None, - fp8_scales=None, - fp8_out_scale=None, - ): - if fp8_scales is not None: - use_fp8 = True - (q_scale, k_scale, v_scale, p_scale) = fp8_scales - float8 = current_platform.fp8_dtype() - - def check_and_convert(t, scale): - if t.dtype != float8: - descale = 1.0 / scale - ts = (t * descale).clamp(min=float8_info.min, max=float8_info.max) - return ts.to(float8) - else: - return t - - q = check_and_convert(q, q_scale) - k = check_and_convert(k, k_scale) - v = check_and_convert(v, v_scale) - else: - use_fp8 = False - q_scale = k_scale = v_scale = p_scale = 1.0 - - if o is None: - o = torch.empty_like(q, dtype=v.dtype) - - check_args( - q, - k, - v, - o, - varlen=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if True: # varlen - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, seqlen_q, nheads_q, head_size = q.shape - _, seqlen_k, nheads_k, _ = k.shape - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - - # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128, 256} - if head_size not in unpadded_head_dims: - padded_d_model = None - for i in unpadded_head_dims: - if i > head_size: - padded_d_model = i - break - assert padded_d_model is not None - else: - padded_d_model = head_size - - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META["BLOCK_M"]), - nheads_q, - batch, - ) - - encoded_softmax = None - - # Seed the RNG so we get reproducible results for testing. - philox_seed = 0x1BF52 - philox_offset = 0x1D4B42 - - if bias is not None: - bias_strides = ( - bias.stride(0), - bias.stride(1), - bias.stride(2), - bias.stride(3), - ) - else: - bias_strides = (0, 0, 0, 0) - - p_descale = 1.0 / p_scale - o_descale = 1.0 / fp8_out_scale.item() if fp8_out_scale is not None else 1.0 - - arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q - arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k - - attn_fwd[grid]( - q, - k, - v, - bias, - sm_scale, - q_scale, - k_scale, - v_scale, - p_scale, - p_descale, - o_descale, - None, - o, - *q_strides, - *k_strides, - *v_strides, - *o_strides, - *bias_strides, - cu_seqlens_q, - cu_seqlens_k, - dropout_p=0.0, - philox_seed=philox_seed, - philox_offset_base=philox_offset, - encoded_softmax=encoded_softmax, - HQ=nheads_q, - HK=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=arg_max_seqlens_q, - MAX_SEQLENS_K=arg_max_seqlens_k, - IS_CAUSAL=causal, - VARLEN=True, - BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, - ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False, - USE_FP8=use_fp8, - USE_FP8_OUT=fp8_out_scale is not None, - ) - - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = head_size - ctx.causal = causal - ctx.dropout_p = 0.0 - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = False - return o, encoded_softmax - - -triton_attention = _attention.apply diff --git a/vllm/envs.py b/vllm/envs.py index 2aa6afcab..b99e25243 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -18,7 +18,6 @@ if TYPE_CHECKING: VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_NCCL_SO_PATH: str | None = None LD_LIBRARY_PATH: str | None = None - VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False VLLM_FLASH_ATTN_VERSION: int | None = None LOCAL_RANK: int = 0 @@ -521,10 +520,6 @@ environment_variables: dict[str, Callable[[], Any]] = { # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl # library file in the locations specified by `LD_LIBRARY_PATH` "LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None), - # flag to control if vllm should use triton flash attention - "VLLM_USE_TRITON_FLASH_ATTN": lambda: ( - os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1") - ), # Use separate prefill and decode kernels for V1 attention instead of # the unified triton kernel. "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: ( @@ -1554,7 +1549,6 @@ def compute_hash() -> str: "VLLM_PP_LAYER_PARTITION", "VLLM_MLA_DISABLE", "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", - "VLLM_USE_TRITON_FLASH_ATTN", "VLLM_USE_TRITON_AWQ", "VLLM_DP_RANK", "VLLM_DP_SIZE", diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 5318bdb8b..f5f680825 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -49,25 +49,8 @@ _ROCM_UNSUPPORTED_MODELS: list[str] = [] # Models partially supported by ROCm. # Architecture -> Reason. -_ROCM_SWA_REASON = ( - "Sliding window attention (SWA) is not yet supported in " - "Triton flash attention. For half-precision SWA support, " - "please use CK flash attention by setting " - "`VLLM_USE_TRITON_FLASH_ATTN=0`" -) -_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = { - "Qwen2ForCausalLM": _ROCM_SWA_REASON, - "MistralForCausalLM": _ROCM_SWA_REASON, - "MixtralForCausalLM": _ROCM_SWA_REASON, - "PaliGemmaForConditionalGeneration": ( - "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma" - ), - "Phi3VForCausalLM": ( - "ROCm Triton flash attention may run into compilation errors due to " - "excessive use of shared memory. If this happens, disable Triton FA " - "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`" - ), -} +_ROCM_SWA_REASON = () +_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {} _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = { "0x74a0": "AMD_Instinct_MI300A", "0x74a1": "AMD_Instinct_MI300X", diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 4eddaf56d..69226763a 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -37,7 +37,6 @@ _GLOBAL_RUNTIME_DATA = dict[str, str | int | bool]() _USAGE_ENV_VARS_TO_COLLECT = [ "VLLM_USE_MODELSCOPE", - "VLLM_USE_TRITON_FLASH_ATTN", "VLLM_ATTENTION_BACKEND", "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_PP_LAYER_PARTITION", diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 0149639e8..54ad3acb9 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -5,22 +5,18 @@ from typing import ClassVar import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, is_quantized_kv_cache, ) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd -from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) -from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability -from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -99,54 +95,17 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): "TritonMLA V1 with FP8 KV cache not yet supported" ) - self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN - self.triton_fa_func = triton_attention if HAS_TRITON else None - - def _flash_attn_varlen_diff_headdims_rocm( - self, q, k, v, softmax_scale=None, **kwargs - ): - assert self.triton_fa_func is not None - - # Triton Attention requires a padded V - padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) - # The output of triton_attention is a tuple of - # [output_tensor, encoded_softmax] where encoded_softmax is always None - output_tensor, _ = self.triton_fa_func( - q, - k, - padded_v, - None, # output - kwargs["cu_seqlens_q"], - kwargs["cu_seqlens_k"], - kwargs["max_seqlen_q"], - kwargs["max_seqlen_k"], - kwargs["causal"], - softmax_scale, - None, # bias - ) - - return output_tensor - def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs ): - if ( - current_platform.is_rocm() - and self.use_triton_flash_attn - and not return_softmax_lse - ): - return self._flash_attn_varlen_diff_headdims_rocm( - q, k, v, softmax_scale=softmax_scale, **kwargs - ) - else: - return super()._flash_attn_varlen_diff_headdims( - q, - k, - v, - return_softmax_lse=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) + return super()._flash_attn_varlen_diff_headdims( + q, + k, + v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) def _forward_decode( self,