[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model (#8533)
@@ -138,42 +138,47 @@ class JambaMambaMixer(nn.Module):
         self.c_layernorm = RMSNorm(self.ssm_state_size,
                                    eps=config.rms_norm_eps)
 
-    def mamba_forward(self,
-                      hidden_states: torch.Tensor,
-                      cache_params: MambaCacheParams = None):
+    def forward(self, hidden_states: torch.Tensor,
+                attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
+                ssm_state: torch.Tensor):
 
         # 1. Gated MLP's linear projection
-        projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
-        hidden_states, gate = projected_states.chunk(2, dim=1)
+        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
+        hidden_states, gate = projected_states.chunk(2, dim=-2)
 
         # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
-        if cache_params is not None and not cache_params.is_prompt:
-            hidden_states = causal_conv1d_update(
-                hidden_states.squeeze(-1),
-                cache_params.conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-            )
-            hidden_states = hidden_states.unsqueeze(-1)
-        else:
-            if cache_params is not None:
-                conv_states = nn.functional.pad(
-                    hidden_states,
-                    (self.conv_kernel_size - hidden_states.shape[-1], 0))
-                cache_params.conv_state.copy_(conv_states)
-
-            hidden_states, _ = causal_conv1d_fn(
+        if attn_metadata.query_start_loc is not None \
+                and attn_metadata.context_lens_tensor is not None:
+            # |---------- N-1 iteration --------|
+            # |---------------- N iteration ---------------------|
+            # |- tokenA -|......................|-- newTokens ---|
+            # |---------- context_len ----------|
+            # |-------------------- seq_len ---------------------|
+            #                                   |-- query_len ---|
+            hidden_states = causal_conv1d_fn(
                 hidden_states,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
+                conv_states=conv_state,
+                has_initial_state=attn_metadata.context_lens_tensor > 0,
+                query_start_loc=attn_metadata.query_start_loc)
+        else:
+            hidden_states = causal_conv1d_update(
+                hidden_states.transpose(0, 1),
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+            )
+            hidden_states = hidden_states.transpose(0, 1)
 
         # 3. State Space Model sequence transformation
         # 3.a. input varying initialization of time_step, B and C
-        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]
+        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
 
         time_step, B, C = torch.split(
             ssm_parameters,
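Read as a whole, the packed-varlen branch above operates on all sequences of the step concatenated along the token axis: query_start_loc marks where each sequence starts, context_lens_tensor > 0 (has_initial_state) says which sequences continue from tokens already absorbed into the state, and conv_state carries the last few inputs of every sequence between steps. A rough reference of those semantics in plain PyTorch, not the CUDA kernel; the helper name, the silu activation and the toy shapes are illustrative assumptions:

import torch

def causal_conv1d_varlen_ref(x, weight, bias, conv_states, query_start_loc,
                             has_initial_state):
    """Reference sketch of a packed-varlen depthwise causal conv.

    x:                 (dim, total_tokens)    - all sequences concatenated
    weight:            (dim, width)           - depthwise filter per channel
    bias:              (dim,)
    conv_states:       (batch, dim, width-1)  - rolling input cache per sequence
    query_start_loc:   (batch + 1,)           - token offset of each sequence
    has_initial_state: (batch,) bool          - True when cached context exists
    """
    dim, _ = x.shape
    width = weight.shape[1]
    out = torch.empty_like(x)
    for i in range(query_start_loc.numel() - 1):
        s, e = query_start_loc[i].item(), query_start_loc[i + 1].item()
        seq = x[:, s:e]
        # Continue from cached inputs for running sequences, zeros for fresh ones.
        init = conv_states[i] if has_initial_state[i] else torch.zeros(
            dim, width - 1, dtype=x.dtype)
        padded = torch.cat([init, seq], dim=-1)
        for t in range(seq.shape[-1]):
            window = padded[:, t:t + width]          # causal window per channel
            out[:, s + t] = (window * weight).sum(-1) + bias
        # Store the last width-1 inputs so the next chunk can continue this sequence.
        conv_states[i] = padded[:, -(width - 1):]
    return torch.nn.functional.silu(out)             # activation assumed to be silu

# Toy usage: two prompts of length 3 and 2 packed into a single call.
dim, width = 4, 3
x = torch.randn(dim, 5)
weight, bias = torch.randn(dim, width), torch.zeros(dim)
conv_states = torch.zeros(2, dim, width - 1)
query_start_loc = torch.tensor([0, 3, 5])
has_initial_state = torch.tensor([False, False])
print(causal_conv1d_varlen_ref(x, weight, bias, conv_states, query_start_loc,
                               has_initial_state).shape)   # torch.Size([4, 5])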
@@ -184,72 +189,46 @@ class JambaMambaMixer(nn.Module):
         B = self.b_layernorm(B.contiguous())
         C = self.c_layernorm(C.contiguous())
 
-        discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
+        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         time_proj_bias = (self.dt_proj.bias.float() if hasattr(
             self.dt_proj, "bias") else None)
-        if cache_params is not None and not cache_params.is_prompt:
-            scan_outputs = selective_state_update(
-                cache_params.ssm_state,
-                hidden_states[..., 0],
-                discrete_time_step[..., 0],
-                self.A,
-                B[:, 0],
-                C[:, 0],
-                self.D,
-                gate[..., 0],
-                time_proj_bias,
-                dt_softplus=True,
-            ).unsqueeze(-1)
-        else:
-            scan_outputs, ssm_state = selective_scan_fn(
+
+        if attn_metadata.query_start_loc is not None \
+                and attn_metadata.context_lens_tensor is not None:
+            scan_outputs = selective_scan_fn(
                 hidden_states,
+                ssm_state,
                 discrete_time_step,
                 self.A,
-                B.transpose(1, 2),
-                C.transpose(1, 2),
+                B.transpose(-2, -1),
+                C.transpose(-2, -1),
                 self.D.float(),
                 gate,
                 time_proj_bias,
                 delta_softplus=True,
-                return_last_state=True,
+                has_initial_state=attn_metadata.context_lens_tensor > 0,
+                query_start_loc=attn_metadata.query_start_loc)
+        else:
+            scan_outputs = selective_state_update(
+                ssm_state,
+                hidden_states.transpose(0, 1),
+                discrete_time_step.transpose(0, 1),
+                self.A,
+                B,
+                C,
+                self.D,
+                gate.transpose(0, 1),
+                time_proj_bias,
+                dt_softplus=True,
             )
-            if ssm_state is not None and cache_params is not None:
-                cache_params.ssm_state.copy_(ssm_state)
+            scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
+        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
+                                                                      -1))[0]
         return contextualized_states
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
-    ):
-        if attn_metadata.prefill_metadata is not None:
-            offset = 0
-            for i, prompt_len in enumerate(
-                    attn_metadata.prefill_metadata.seq_lens):
-                cache = MambaCacheParams(True,
-                                         conv_state=conv_state[i].unsqueeze(0),
-                                         ssm_state=ssm_state[i].unsqueeze(0))
-                hidden_states[offset:offset + prompt_len].copy_(
-                    self.mamba_forward(hidden_states[offset:offset +
                                                     prompt_len].unsqueeze(0),
-                                       cache_params=cache)[0])
-                offset += prompt_len
-        else:
-            cache = MambaCacheParams(False,
-                                     conv_state=conv_state,
-                                     ssm_state=ssm_state)
-            hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
-                                               cache_params=cache)
-            hidden_states = hidden_states.squeeze(1)
-
-        return hidden_states
-
 
 class JambaMoE(nn.Module):
 
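For contrast with the per-prompt Python loop removed above, a small standalone example of the metadata that lets a single varlen call cover a mixed prefill batch; the lengths are made up, and in vLLM this metadata is produced by the attention backend rather than built by hand:

import torch

# Three sequences in one step: 0 and 2 are fresh prompts, 1 continues 9 cached
# tokens (the prefill-chunking case).
query_lens = torch.tensor([5, 2, 7])            # new tokens per sequence this step
context_lens_tensor = torch.tensor([0, 9, 0])   # tokens already held in the mamba state

# query_start_loc marks where each sequence's tokens begin in the packed batch,
# so one kernel launch replaces the old per-prompt loop.
query_start_loc = torch.zeros(query_lens.numel() + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)
has_initial_state = context_lens_tensor > 0

print(query_start_loc.tolist())    # [0, 5, 7, 14]
print(has_initial_state.tolist())  # [False, True, False]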
@@ -571,8 +550,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
         lora_config: Optional[LoRAConfig] = None,
         scheduler_config: Optional[SchedulerConfig] = None,
     ) -> None:
-        assert not scheduler_config.chunked_prefill_enabled, \
-            "Jamba currently does not support chunked prefill"
         assert not cache_config.enable_prefix_caching, \
             "Jamba currently does not support prefix caching"
 
@@ -616,18 +593,10 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
 
         if "seqlen_agnostic_capture_inputs" not in kwargs:
             # We get here only on Prefill/Eager mode runs
-            assert all(
-                key in kwargs
-                for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
-
             request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
             finished_requests_ids = kwargs["finished_requests_ids"]
-            self._release_mamba_cache(finished_requests_ids)
-            batch_size = input_ids.shape[0]
-            if attn_metadata.prefill_metadata:
-                batch_size = len(request_ids_to_seq_ids)
-            mamba_cache = self._prepare_current_run_mamba_cache(
-                request_ids_to_seq_ids, batch_size, finished_requests_ids)
+            mamba_cache = self._release_finished_and_prepare_mamba_cache(
+                finished_requests_ids, request_ids_to_seq_ids)
         else:
             # CUDA graph capturing runs
             mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
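A toy sketch of the bookkeeping contract behind the new _release_finished_and_prepare_mamba_cache call used above: slots owned by finished requests are freed first, then every running (request, sequence) pair is mapped to a slot in the batched state tensors. This only illustrates the idea and is not the model's actual implementation:

from typing import Dict, List

class ToyMambaCacheIndex:
    """Hypothetical slot bookkeeping: release finished requests, then assign slots."""

    def __init__(self, num_slots: int):
        self.free = list(range(num_slots))
        self.owner: Dict[tuple, int] = {}   # (request_id, seq_id) -> slot index

    def release(self, finished_requests_ids: List[str]) -> None:
        for key in [k for k in self.owner if k[0] in finished_requests_ids]:
            self.free.append(self.owner.pop(key))

    def prepare(self, request_ids_to_seq_ids: Dict[str, List[int]]) -> List[int]:
        indices = []
        for req_id, seq_ids in request_ids_to_seq_ids.items():
            for seq_id in seq_ids:
                if (req_id, seq_id) not in self.owner:
                    self.owner[(req_id, seq_id)] = self.free.pop(0)
                indices.append(self.owner[(req_id, seq_id)])
        return indices

cache = ToyMambaCacheIndex(num_slots=4)
print(cache.prepare({"req-a": [0], "req-b": [1, 2]}))   # [0, 1, 2]
cache.release(["req-a"])
print(cache.prepare({"req-b": [1, 2]}))                 # [1, 2] - req-b keeps its slots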
@@ -699,13 +668,15 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
 
     def _prepare_current_run_mamba_cache(
             self, request_ids_to_seq_ids: Dict[str, list[int]],
-            batch_size: int, finished_requests_ids: List[str]):
+            finished_requests_ids: List[str]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         running_indices = []
         request_ids_to_seq_ids_flatten = [
             (req_id, seq_id)
             for req_id, seq_ids in request_ids_to_seq_ids.items()
             for seq_id in seq_ids
         ]
+        batch_size = len(request_ids_to_seq_ids_flatten)
         for dest_index, (request_id,
                          seq_id) in enumerate(request_ids_to_seq_ids_flatten):
             if request_id in finished_requests_ids:
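A standalone rerun of the flattening added in this hunk, showing how batch_size is now derived from the flattened (request, sequence) pairs instead of being passed in by the caller; the ids are made up:

request_ids_to_seq_ids = {"req-a": [0], "req-b": [1, 2]}   # e.g. n>1 sampling for req-b

request_ids_to_seq_ids_flatten = [
    (req_id, seq_id)
    for req_id, seq_ids in request_ids_to_seq_ids.items()
    for seq_id in seq_ids
]
batch_size = len(request_ids_to_seq_ids_flatten)

print(request_ids_to_seq_ids_flatten)   # [('req-a', 0), ('req-b', 1), ('req-b', 2)]
print(batch_size)                       # 3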
@@ -769,22 +740,21 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                 seq_ids2index.update({seq_id: to_index})
                 return
 
+    def _release_finished_and_prepare_mamba_cache(
+            self, finished_requests_ids,
+            request_ids_to_seq_ids) -> Tuple[torch.Tensor, torch.Tensor]:
+        self._release_mamba_cache(finished_requests_ids)
+        return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
+                                                     finished_requests_ids)
+
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
         Copy the relevant Mamba cache into the CUDA graph input buffer
         that was provided during the capture runs
         (JambaForCausalLM.mamba_gc_cache_buffer).
         """
-        assert all(
-            key in kwargs
-            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        self._release_mamba_cache(finished_requests_ids)
-        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        cg_batch_size = input_buffers['input_ids'].shape[0]
-        self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                              cg_batch_size,
-                                              finished_requests_ids)
+        self._release_finished_and_prepare_mamba_cache(
+            kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"])
 
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         """
@@ -819,7 +789,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
         hidden_size = self.config.hidden_size
         conv_state_shape = (
             self.config.mamba_expand * hidden_size // world_size,
-            self.config.mamba_d_conv,
+            self.config.mamba_d_conv - 1,
         )
         temporal_state_shape = (
             self.config.mamba_expand * self.config.hidden_size // world_size,
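A worked example of the conv state shape above, with illustrative Jamba-like configuration values; the switch to mamba_d_conv - 1 reflects that the rolling conv cache only needs the previous d_conv - 1 inputs per channel:

hidden_size = 4096      # illustrative values, not read from a real config
mamba_expand = 2
mamba_d_conv = 4
world_size = 1          # tensor-parallel degree

conv_state_shape = (
    mamba_expand * hidden_size // world_size,   # 8192 channels on this rank
    mamba_d_conv - 1,                           # 3 cached inputs per channel
)
print(conv_state_shape)   # (8192, 3)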