[Misc] Fix typo: seperator -> separator in flashmla_sparse.py (#32411)
Signed-off-by: Guofang Tang <tinggofun@gmail.com>
Co-authored-by: Guofang Tang <tinggofun@gmail.com>
This commit is contained in:
@@ -149,7 +149,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
|
||||
cache_lens: torch.Tensor
|
||||
|
||||
@dataclass
|
||||
class FP8SeperatePrefillDecode:
|
||||
class FP8SeparatePrefillDecode:
|
||||
@dataclass
|
||||
class Decode:
|
||||
kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
|
||||
@@ -196,7 +196,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
|
||||
decode: Decode | None = None
|
||||
prefill: Prefill | None = None
|
||||
|
||||
fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None
|
||||
fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
|
||||
fp8_use_mixed_batch: bool = False
|
||||
|
||||
|
||||
@@ -485,7 +485,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
def _build_fp8_separate_prefill_decode(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode":
|
||||
) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
|
||||
@@ -496,7 +496,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
)
|
||||
)
|
||||
|
||||
FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode
|
||||
FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
|
||||
fp8_metadata = FP8Meta(
|
||||
num_decodes=num_decodes,
|
||||
num_prefills=num_prefills,
|
||||
@@ -659,7 +659,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
|
||||
|
||||
fp8_extra_metadata: (
|
||||
FlashMLASparseMetadata.FP8SeperatePrefillDecode
|
||||
FlashMLASparseMetadata.FP8SeparatePrefillDecode
|
||||
| FlashMLASparseMetadata.FP8KernelMetadata
|
||||
| None
|
||||
) = None
|
||||
@@ -765,7 +765,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
fp8_metadata = attn_metadata.fp8_extra_metadata
|
||||
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
|
||||
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
|
||||
num_decodes = fp8_metadata.num_decodes
|
||||
|
||||
prefill_request_ids = None
|
||||
@@ -794,7 +794,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
)
|
||||
|
||||
fp8_metadata = attn_metadata.fp8_extra_metadata
|
||||
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
|
||||
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
|
||||
|
||||
def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
|
||||
# Reshape q: (num_decode_tokens, num_heads, head_dim)
|
||||
|
||||
Reference in New Issue
Block a user