[BugFix] fix 3 issues: (1) using metadata for causal-conv1d, (2) indexing overflow in v1 vLLM, and (3) init_states in v0 (#20838)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
This commit is contained in:
committed by
GitHub
parent
ed10f3cea1
commit
f29fd8a7f8
@@ -573,8 +573,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1) # this is the form that causal-conv see
|
||||
if mamba2_metadata.cu_seqlen is None:
|
||||
mamba2_metadata = update_metadata(
|
||||
x, attn_metadata.query_start_loc, mamba2_metadata)
|
||||
mamba2_metadata = update_metadata(x, query_start_loc_p,
|
||||
mamba2_metadata)
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
x,
|
||||
conv_weights,
|
||||
@@ -583,6 +583,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
metadata=mamba2_metadata,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
@@ -593,9 +594,14 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
initial_states = None
|
||||
if (has_initial_states_p is not None and prep_initial_states):
|
||||
# making a copy of the states
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
if envs.VLLM_USE_V1:
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
else:
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:num_prefills, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
|
||||
scan_output, varlen_state = mamba_chunk_scan_combined(
|
||||
hidden_states_p.view(1, num_prefill_tokens,
|
||||
|
||||
Reference in New Issue
Block a user