[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model (#8533)

This commit is contained in:
Mor Zusman
2024-09-30 00:35:58 +03:00
committed by GitHub
parent 6c9ba48fde
commit f13a07b1f8
13 changed files with 1176 additions and 894 deletions

View File

@@ -138,42 +138,47 @@ class JambaMambaMixer(nn.Module):
self.c_layernorm = RMSNorm(self.ssm_state_size,
eps=config.rms_norm_eps)
def mamba_forward(self,
hidden_states: torch.Tensor,
cache_params: MambaCacheParams = None):
def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
ssm_state: torch.Tensor):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
hidden_states, gate = projected_states.chunk(2, dim=1)
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if cache_params is not None and not cache_params.is_prompt:
hidden_states = causal_conv1d_update(
hidden_states.squeeze(-1),
cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if cache_params is not None:
conv_states = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0))
cache_params.conv_state.copy_(conv_states)
hidden_states, _ = causal_conv1d_fn(
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
@@ -184,72 +189,46 @@ class JambaMambaMixer(nn.Module):
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if cache_params is not None and not cache_params.is_prompt:
scan_outputs = selective_state_update(
cache_params.ssm_state,
hidden_states[..., 0],
discrete_time_step[..., 0],
self.A,
B[:, 0],
C[:, 0],
self.D,
gate[..., 0],
time_proj_bias,
dt_softplus=True,
).unsqueeze(-1)
else:
scan_outputs, ssm_state = selective_scan_fn(
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
discrete_time_step,
self.A,
B.transpose(1, 2),
C.transpose(1, 2),
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
return_last_state=True,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_state.copy_(ssm_state)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
):
if attn_metadata.prefill_metadata is not None:
offset = 0
for i, prompt_len in enumerate(
attn_metadata.prefill_metadata.seq_lens):
cache = MambaCacheParams(True,
conv_state=conv_state[i].unsqueeze(0),
ssm_state=ssm_state[i].unsqueeze(0))
hidden_states[offset:offset + prompt_len].copy_(
self.mamba_forward(hidden_states[offset:offset +
prompt_len].unsqueeze(0),
cache_params=cache)[0])
offset += prompt_len
else:
cache = MambaCacheParams(False,
conv_state=conv_state,
ssm_state=ssm_state)
hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
cache_params=cache)
hidden_states = hidden_states.squeeze(1)
return hidden_states
class JambaMoE(nn.Module):
@@ -571,8 +550,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None:
assert not scheduler_config.chunked_prefill_enabled, \
"Jamba currently does not support chunked prefill"
assert not cache_config.enable_prefix_caching, \
"Jamba currently does not support prefix caching"
@@ -616,18 +593,10 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids)
mamba_cache = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, batch_size, finished_requests_ids)
mamba_cache = self._release_finished_and_prepare_mamba_cache(
finished_requests_ids, request_ids_to_seq_ids)
else:
# CUDA graph capturing runs
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
@@ -699,13 +668,15 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
batch_size: int, finished_requests_ids: List[str]):
finished_requests_ids: List[str]
) -> Tuple[torch.Tensor, torch.Tensor]:
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
batch_size = len(request_ids_to_seq_ids_flatten)
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
@@ -769,22 +740,21 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
seq_ids2index.update({seq_id: to_index})
return
def _release_finished_and_prepare_mamba_cache(
self, finished_requests_ids,
request_ids_to_seq_ids) -> Tuple[torch.Tensor, torch.Tensor]:
self._release_mamba_cache(finished_requests_ids)
return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
finished_requests_ids)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)
self._release_finished_and_prepare_mamba_cache(
kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"])
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
@@ -819,7 +789,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
hidden_size = self.config.hidden_size
conv_state_shape = (
self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_conv,
self.config.mamba_d_conv - 1,
)
temporal_state_shape = (
self.config.mamba_expand * self.config.hidden_size // world_size,