diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 0cafd4bbf..d40f4d4d3 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -149,7 +149,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
     cache_lens: torch.Tensor

     @dataclass
-    class FP8SeperatePrefillDecode:
+    class FP8SeparatePrefillDecode:
         @dataclass
         class Decode:
             kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
@@ -196,7 +196,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
         decode: Decode | None = None
         prefill: Prefill | None = None

-    fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None
+    fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
     fp8_use_mixed_batch: bool = False

@@ -485,7 +485,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
     def _build_fp8_separate_prefill_decode(
         self,
         common_attn_metadata: CommonAttentionMetadata,
-    ) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode":
+    ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
        num_tokens = common_attn_metadata.num_actual_tokens

        (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
@@ -496,7 +496,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
            )
        )

-        FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode
+        FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
        fp8_metadata = FP8Meta(
            num_decodes=num_decodes,
            num_prefills=num_prefills,
@@ -659,7 +659,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata: (
-            FlashMLASparseMetadata.FP8SeperatePrefillDecode
+            FlashMLASparseMetadata.FP8SeparatePrefillDecode
            | FlashMLASparseMetadata.FP8KernelMetadata
            | None
        ) = None
@@ -765,7 +765,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        fp8_metadata = attn_metadata.fp8_extra_metadata
-        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
+        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
        num_decodes = fp8_metadata.num_decodes

        prefill_request_ids = None
@@ -794,7 +794,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
        )

        fp8_metadata = attn_metadata.fp8_extra_metadata
-        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
+        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)

        def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
            # Reshape q: (num_decode_tokens, num_heads, head_dim)