diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 500370d9f..766bc46ce 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -560,6 +560,11 @@ class RMSNormGated(CustomOp):
             activation=self.activation,
         )
 
+    def forward_xpu(
+        self, x: torch.Tensor, z: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        return self.forward_cuda(x, z)
+
 
 class LayerNorm(nn.Module):
     """
diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index 9b95e00d2..bdbe8b4a3 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -262,6 +262,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
             else 0
         )
         self.gqa_interleaved_layout = gqa_interleaved_layout
+        self._forward_method = (
+            self.forward_xpu if current_platform.is_xpu() else self.forward_cuda
+        )
 
         # QKV
         self.conv_dim = self.key_dim * 2 + self.value_dim
@@ -493,6 +496,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
+    ):
+        self._forward_method(hidden_states, output)
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
     ):
         """
         Forward pass with three parts:
@@ -567,6 +577,90 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
         core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
         output[:num_tokens], _ = self.out_proj(core_attn_out)
 
+    def forward_xpu(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+    ):
+        """
+        Forward pass with three parts:
+        1. Input projection
+        2. Core attention (custom op)
+        3. Output projection
+        """
+        num_tokens = hidden_states.size(0)
+
+        assert not hasattr(self, "in_proj_qkv"), "LoRA isn't supported on XPU."
+
+        # ============================================================
+        # Part 1: Input Projection
+        # ============================================================
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+        projected_states_ba, _ = self.in_proj_ba(hidden_states)
+
+        # ============================================================
+        # Part 2: Core Attention
+        # ============================================================
+        forward_context = get_forward_context()
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+        core_attn_out = torch.zeros(
+            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        z = torch.empty_like(core_attn_out)
+        if attn_metadata is not None:
+            attn_metadata = attn_metadata[self.prefix]
+
+            # TODO: XPU does not support this param yet
+            spec_sequence_masks = attn_metadata.spec_sequence_masks
+            assert spec_sequence_masks is None
+
+            conv_weights = self.conv1d.weight.view(
+                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+            )
+
+            conv_state = self.kv_cache[0]
+            ssm_state = self.kv_cache[1]
+
+            torch.ops._xpu_C.gdn_attention(
+                core_attn_out,
+                z,
+                projected_states_qkvz,
+                projected_states_ba,
+                self.num_k_heads,
+                self.num_v_heads,
+                self.head_k_dim,
+                self.head_v_dim,
+                conv_state=conv_state,
+                ssm_state=ssm_state,
+                conv_weights=conv_weights,
+                conv_bias=self.conv1d.bias,
+                activation=self.activation,
+                A_log=self.A_log,
+                dt_bias=self.dt_bias,
+                num_prefills=attn_metadata.num_prefills,
+                num_decodes=attn_metadata.num_decodes,
+                has_initial_state=attn_metadata.has_initial_state,
+                non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc,
+                non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor,
+                num_actual_tokens=attn_metadata.num_actual_tokens,
+                tp_size=self.tp_size,
+                reorder_input=not self.gqa_interleaved_layout,
+            )
+
+        # ============================================================
+        # Part 3: Output Projection
+        # ============================================================
+        z_shape_og = z.shape
+        # Reshape input data into 2D tensor
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+        output[:num_tokens], _ = self.out_proj(core_attn_out)
+
     def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
         """Warm up GDN prefill kernels during V1 profiling.
 
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 2a56ff5c6..ffc765257 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -218,6 +218,57 @@ class XPUPlatform(Platform):
             # ref. https://openucx.readthedocs.io/en/master/faq.html
            os.environ["UCX_MEMTYPE_CACHE"] = "n"
 
+    @classmethod
+    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
+        super().update_block_size_for_backend(vllm_config)
+        from vllm.config.vllm import get_layers_from_vllm_config
+        from vllm.model_executor.layers.attention_layer_base import (
+            AttentionLayerBase,
+        )
+        from vllm.utils.math_utils import cdiv
+
+        cache_config = vllm_config.cache_config
+        # Special fix for GDN: the kernel only supports block sizes divisible by 64.
+        attn_layers = get_layers_from_vllm_config(
+            vllm_config,
+            AttentionLayerBase,  # type: ignore[type-abstract]
+        )
+
+        kernel_block_size = None
+        for layer in attn_layers.values():
+            b = layer.get_attn_backend()
+            if b.get_name() == "GDN_ATTN":
+                kernel_block_size = 64
+                break
+
+        if kernel_block_size is None:
+            return
+        new_block_size = (
+            cdiv(cache_config.block_size, kernel_block_size) * kernel_block_size
+        )
+        if new_block_size == cache_config.block_size:
+            return
+
+        if cache_config.mamba_cache_mode == "align":
+            cache_config.mamba_block_size = new_block_size
+        original_mamba_page_size_padded = cache_config.mamba_page_size_padded
+        if cache_config.mamba_page_size_padded is not None:
+            attn_page_size_1_token = (
+                cache_config.mamba_page_size_padded // cache_config.block_size
+            )
+            cache_config.mamba_page_size_padded = (
+                new_block_size * attn_page_size_1_token
+            )
+        cache_config.block_size = new_block_size
+        logger.info(
+            "[XPU] Setting attention block size to %d tokens to ensure it is a "
+            "multiple of %d; mamba_page_size_padded set to %d bytes accordingly "
+            "(previously %d bytes).",
+            new_block_size,
+            kernel_block_size,
+            cache_config.mamba_page_size_padded,
+            original_mamba_page_size_padded,
+        )
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
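
For readers unfamiliar with the dispatch pattern used in `gdn_linear_attn.py`, the sketch below shows the idea in isolation: the platform-specific implementation is bound once at construction time, so `forward()` itself stays branch-free. This is a minimal illustration, not the vLLM class; the `is_xpu` flag stands in for `current_platform.is_xpu()`, and the print statements stand in for the real `forward_cuda` / `forward_xpu` bodies.

```python
# Minimal sketch of construction-time platform dispatch (illustrative only).
class DispatchSketch:
    def __init__(self, is_xpu: bool):
        # Pick the platform-specific implementation once, at __init__ time.
        self._forward_method = self.forward_xpu if is_xpu else self.forward_cuda

    def forward(self, hidden_states, output):
        # The hot path simply delegates; no per-call platform check.
        self._forward_method(hidden_states, output)

    def forward_cuda(self, hidden_states, output):
        print("CUDA path")

    def forward_xpu(self, hidden_states, output):
        print("XPU path")


DispatchSketch(is_xpu=True).forward(None, None)   # -> XPU path
DispatchSketch(is_xpu=False).forward(None, None)  # -> CUDA path
```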
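The XPU override of `update_block_size_for_backend` rounds the KV-cache block size up to the GDN kernel's 64-token granularity and rescales `mamba_page_size_padded` so the padded bytes per token stay constant. A standalone sketch of that arithmetic follows, assuming a local `cdiv` helper equivalent to `vllm.utils.math_utils.cdiv`; the example sizes are hypothetical and not taken from the patch.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, equivalent to vllm.utils.math_utils.cdiv."""
    return -(-a // b)


def round_block_size(
    block_size: int, page_size_padded: int | None, kernel_block_size: int = 64
) -> tuple[int, int | None]:
    """Round block_size up to a multiple of kernel_block_size and rescale the page size."""
    new_block_size = cdiv(block_size, kernel_block_size) * kernel_block_size
    if new_block_size == block_size or page_size_padded is None:
        return new_block_size, page_size_padded
    # Padded bytes per token stay constant, so the page grows with the block.
    page_size_per_token = page_size_padded // block_size
    return new_block_size, new_block_size * page_size_per_token


# Hypothetical sizes for illustration only.
print(round_block_size(16, 16 * 1024))    # (64, 65536): 16-token blocks rounded up to 64
print(round_block_size(128, 128 * 1024))  # (128, 131072): already a multiple of 64, unchanged
```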