Revert "[BugFix][AMD] Compatible patch for latest AITER(05/07/2025)" (#17910)

This commit is contained in:
Michael Goin
2025-05-09 09:58:18 -06:00
committed by GitHub
parent 6e5595ca39
commit 85b72cb7b1
4 changed files with 23 additions and 54 deletions

View File

@@ -1213,9 +1213,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_output, attn_softmax_lse = \ attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims( self._flash_attn_varlen_diff_headdims(
q, q=q,
k, k=k,
v, v=v,
cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len, max_seqlen_q=prefill_metadata.max_query_len,
@@ -1267,9 +1267,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._flash_attn_varlen_diff_headdims( output = self._flash_attn_varlen_diff_headdims(
q, q=q,
k, k=k,
v, v=v,
cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len, max_seqlen_q=prefill_metadata.max_prefill_seq_len,

View File

@@ -53,7 +53,7 @@ class AiterMLABackend(MLACommonBackend):
@dataclass @dataclass
class AiterMLAMetadata(MLACommonMetadata): class AiterMLAMetadata(MLACommonMetadata):
# The following 5 tensors are for current version of AITER MLA # The following 4 tensors are for current version of AITER MLA
block_table_bound: Optional[torch.Tensor] = None block_table_bound: Optional[torch.Tensor] = None
# The indptr of the paged kv cache, shape: [batch_size + 1] # The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,10 +63,6 @@ class AiterMLAMetadata(MLACommonMetadata):
# the paged kv cache, shape: [batch_size] # the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens: Optional[torch.Tensor] = None paged_kv_last_page_lens: Optional[torch.Tensor] = None
# This is just to make new AITER MLA API work
# -- MTP support is not added yet.
qo_indptr: Optional[torch.Tensor] = None
@property @property
def prefill_metadata(self): def prefill_metadata(self):
prefill_metadata = super().prefill_metadata prefill_metadata = super().prefill_metadata
@@ -78,7 +74,6 @@ class AiterMLAMetadata(MLACommonMetadata):
prefill_metadata\ prefill_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens .paged_kv_last_page_lens = self.paged_kv_last_page_lens
prefill_metadata.block_table_bound = self.block_table_bound prefill_metadata.block_table_bound = self.block_table_bound
prefill_metadata.qo_indptr = self.qo_indptr
# update the cache # update the cache
self._cached_prefill_metadata = self.__class__( self._cached_prefill_metadata = self.__class__(
@@ -98,7 +93,6 @@ class AiterMLAMetadata(MLACommonMetadata):
decode_metadata\ decode_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens .paged_kv_last_page_lens = self.paged_kv_last_page_lens
decode_metadata.block_table_bound = self.block_table_bound decode_metadata.block_table_bound = self.block_table_bound
decode_metadata.qo_indptr = self.qo_indptr
# update the cache # update the cache
self._cached_decode_metadata = self.__class__( self._cached_decode_metadata = self.__class__(
@@ -142,7 +136,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_indptr: list[int] = [0] self.paged_kv_indptr: list[int] = [0]
self.paged_kv_last_page_lens: list[int] = [] self.paged_kv_last_page_lens: list[int] = []
self.total_blocks = 0 self.total_blocks = 0
self.qo_indptr: list[int] = [0]
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
prefix_cache_hit: bool): prefix_cache_hit: bool):
@@ -215,7 +208,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_indices.extend(block_table[:block_table_bound]) self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound) block_table_bound)
self.qo_indptr.append(self.qo_indptr[-1] + 1)
last_page_len = seq_len % self.block_size last_page_len = seq_len % self.block_size
if last_page_len == 0: if last_page_len == 0:
@@ -234,8 +226,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_indptr.extend([last_paged_kv_indptr] * self.paged_kv_indptr.extend([last_paged_kv_indptr] *
cuda_graph_pad_size) cuda_graph_pad_size)
self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
last_qo_indptr = self.qo_indptr[-1]
self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
# For current version of AITER MLA # For current version of AITER MLA
if len(self.paged_kv_indptr) > 0: if len(self.paged_kv_indptr) > 0:
@@ -255,22 +245,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
1, 1,
device=device, device=device,
dtype=torch.int) dtype=torch.int)
qo_indptr = torch.tensor(self.qo_indptr,
device=device,
dtype=torch.int)
else: else:
paged_kv_indices_tensor = None paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None paged_kv_indptr_tensor = None
paged_kv_last_page_lens_tensor = None paged_kv_last_page_lens_tensor = None
block_table_bound_tensor = None block_table_bound_tensor = None
qo_indptr = None
metadata.paged_kv_indptr = paged_kv_indptr_tensor metadata.paged_kv_indptr = paged_kv_indptr_tensor
metadata.paged_kv_indices = paged_kv_indices_tensor metadata.paged_kv_indices = paged_kv_indices_tensor
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
metadata.block_table_bound = block_table_bound_tensor metadata.block_table_bound = block_table_bound_tensor
metadata.qo_indptr = qo_indptr
return metadata return metadata
@@ -279,17 +263,14 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int):
kv_indices, kv_indptr, last_page_lens, qo_indptr = \ kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
get_aiter_mla_metadata( max_batch_size=max_batch_size,
max_batch_size=max_batch_size, block_size=self.runner.block_size,
block_size=self.runner.block_size, max_block_per_batch=self.runner.get_max_block_per_batch(),
max_block_per_batch=\ device=self.runner.device)
self.runner.get_max_block_per_batch(),
device=self.runner.device)
self._paged_kv_indices_tensor = kv_indices self._paged_kv_indices_tensor = kv_indices
self._paged_kv_indptr_tensor = kv_indptr self._paged_kv_indptr_tensor = kv_indptr
self._paged_kv_last_page_lens_tensor = last_page_lens self._paged_kv_last_page_lens_tensor = last_page_lens
self._qo_indptr_tensor = qo_indptr
with super().graph_capture(max_batch_size): with super().graph_capture(max_batch_size):
yield yield
@@ -297,7 +278,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
del self._paged_kv_indices_tensor del self._paged_kv_indices_tensor
del self._paged_kv_indptr_tensor del self._paged_kv_indptr_tensor
del self._paged_kv_last_page_lens_tensor del self._paged_kv_last_page_lens_tensor
del self._qo_indptr_tensor
def graph_capture_get_metadata_for_batch( def graph_capture_get_metadata_for_batch(
self, self,
@@ -311,12 +291,10 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
paged_kv_indices = self._paged_kv_indices_tensor paged_kv_indices = self._paged_kv_indices_tensor
paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
batch_size] batch_size]
qo_indptr = self._qo_indptr_tensor[:batch_size + 1]
metadata.paged_kv_indptr = paged_kv_indptr metadata.paged_kv_indptr = paged_kv_indptr
metadata.paged_kv_indices = paged_kv_indices metadata.paged_kv_indices = paged_kv_indices
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
metadata.qo_indptr = qo_indptr
return metadata return metadata
@@ -333,7 +311,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
input_buffers[ input_buffers[
"paged_kv_last_page_lens"] = attn_metadata.\ "paged_kv_last_page_lens"] = attn_metadata.\
decode_metadata.paged_kv_last_page_lens decode_metadata.paged_kv_last_page_lens
input_buffers['qo_indptr'] = attn_metadata.qo_indptr
return input_buffers return input_buffers
@@ -353,8 +330,6 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
input_buffers["paged_kv_last_page_lens"].copy_( input_buffers["paged_kv_last_page_lens"].copy_(
attn_metadata.decode_metadata.paged_kv_last_page_lens, attn_metadata.decode_metadata.paged_kv_last_page_lens,
non_blocking=True) non_blocking=True)
input_buffers["qo_indptr"].copy_(
attn_metadata.decode_metadata.qo_indptr, non_blocking=True)
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
@@ -395,9 +370,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
softmax_scale: float, return_softmax_lse: bool, softmax_scale: float, return_softmax_lse: bool,
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
output = self.flash_attn_varlen_func( output = self.flash_attn_varlen_func(
q, q=q,
k, k=k,
v, v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs, **kwargs,
) )
@@ -417,7 +394,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
B = q_nope.shape[0] B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1) q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.empty(B, o = torch.zeros(B,
self.num_heads, self.num_heads,
self.kv_lora_rank, self.kv_lora_rank,
dtype=q.dtype, dtype=q.dtype,
@@ -426,8 +403,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.qo_indptr,
attn_metadata.max_query_len,
attn_metadata.paged_kv_indptr, attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices, attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_lens) attn_metadata.paged_kv_last_page_lens)

View File

@@ -20,8 +20,7 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
paged_kv_last_page_lens = torch.full((max_batch_size, ), paged_kv_last_page_lens = torch.full((max_batch_size, ),
block_size, block_size,
dtype=torch.int32) dtype=torch.int32)
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
def aiter_mla_decode_fwd( def aiter_mla_decode_fwd(
@@ -29,8 +28,6 @@ def aiter_mla_decode_fwd(
kv_buffer: torch.Tensor, kv_buffer: torch.Tensor,
o: torch.Tensor, o: torch.Tensor,
sm_scale: float, sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[torch.Tensor] = None, kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None, kv_last_page_lens: Optional[torch.Tensor] = None,
@@ -63,11 +60,9 @@ def mla_decode_fwd_impl(
mla_decode_fwd(q, mla_decode_fwd(q,
kv_buffer.view(-1, 1, 1, q.shape[-1]), kv_buffer.view(-1, 1, 1, q.shape[-1]),
o, o,
qo_indptr,
kv_indptr, kv_indptr,
kv_indices, kv_indices,
kv_last_page_lens, kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale, sm_scale=sm_scale,
logit_cap=logit_cap) logit_cap=logit_cap)

View File

@@ -123,11 +123,10 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids, fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
sorted_weight_buf, sorted_expert_ids, sorted_weight_buf, sorted_expert_ids,
num_valid_ids, topk, num_valid_ids, topk, w1_scale.view(local_E, -1),
a1_scale.t().contiguous(), w2_scale.view(local_E, -1),
w1_scale.view(local_E, -1), a1_scale.t().contiguous(), *block_shape,
w2_scale.view(local_E, smooth_scale)
-1), *block_shape, smooth_scale)
return out_asm return out_asm