[Misc] Fix typo: Seperate -> Separate in flashmla_sparse.py (#32411)

Signed-off-by: Guofang Tang <tinggofun@gmail.com>
Co-authored-by: Guofang Tang <tinggofun@gmail.com>
Author: Guofang.Tang
Date: 2026-01-17 20:18:30 +08:00
Committed by: GitHub
Parent: 1646fea672
Commit: 2b99f210f5


@@ -149,7 +149,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
     cache_lens: torch.Tensor
     @dataclass
-    class FP8SeperatePrefillDecode:
+    class FP8SeparatePrefillDecode:
         @dataclass
         class Decode:
             kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
@@ -196,7 +196,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
         decode: Decode | None = None
         prefill: Prefill | None = None
-    fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None
+    fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
     fp8_use_mixed_batch: bool = False
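
The field in the hunk above shows why the misspelling spread so widely: FP8SeperatePrefillDecode is a nested dataclass that also appears as one arm of a union-typed field, so the rename has to touch annotations, aliases, and isinstance checks together. A minimal sketch of that shape, using hypothetical names (Metadata, cache_len) rather than the actual vLLM definitions:

from dataclasses import dataclass

@dataclass
class Metadata:
    # Nested FP8 metadata variants; the commit renames the misspelled
    # "FP8SeperatePrefillDecode" to the corrected form used here.
    @dataclass
    class FP8SeparatePrefillDecode:
        num_decodes: int
        num_prefills: int

    @dataclass
    class FP8KernelMetadata:
        cache_len: int

    # Union-typed optional field mirroring fp8_extra_metadata above
    # (Python 3.10+ "X | Y | None" annotation syntax).
    fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None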
@@ -485,7 +485,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
     def _build_fp8_separate_prefill_decode(
         self,
         common_attn_metadata: CommonAttentionMetadata,
-    ) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode":
+    ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
         num_tokens = common_attn_metadata.num_actual_tokens
         (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
@@ -496,7 +496,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
             )
         )
-        FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode
+        FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
         fp8_metadata = FP8Meta(
             num_decodes=num_decodes,
             num_prefills=num_prefills,
@@ -659,7 +659,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
         req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
         fp8_extra_metadata: (
-            FlashMLASparseMetadata.FP8SeperatePrefillDecode
+            FlashMLASparseMetadata.FP8SeparatePrefillDecode
             | FlashMLASparseMetadata.FP8KernelMetadata
             | None
         ) = None
@@ -765,7 +765,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         attn_metadata: FlashMLASparseMetadata,
     ) -> torch.Tensor:
         fp8_metadata = attn_metadata.fp8_extra_metadata
-        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
+        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
         num_decodes = fp8_metadata.num_decodes
         prefill_request_ids = None
@@ -794,7 +794,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         )
         fp8_metadata = attn_metadata.fp8_extra_metadata
-        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
+        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
         def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
             # Reshape q: (num_decode_tokens, num_heads, head_dim)
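
The use-site hunks follow from that union type: the implementation narrows fp8_extra_metadata with an isinstance assert before reading decode-specific attributes, so every assert must also name the corrected class. A hedged sketch of that narrowing, reusing the hypothetical Metadata from the sketch above (after the rename, a stale Metadata.FP8SeperatePrefillDecode reference would fail with AttributeError at runtime):

def fp8_decode_path(meta: Metadata) -> int:
    fp8_metadata = meta.fp8_extra_metadata
    # Narrow the union to the separate prefill/decode variant before
    # touching its fields; both the type checker and this runtime assert
    # depend on the class name being spelled consistently.
    assert isinstance(fp8_metadata, Metadata.FP8SeparatePrefillDecode)
    return fp8_metadata.num_decodes

meta = Metadata(
    fp8_extra_metadata=Metadata.FP8SeparatePrefillDecode(num_decodes=2, num_prefills=1)
)
print(fp8_decode_path(meta))  # prints 2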