Force TRTLLM attention for gpt-oss on SM100 (#22678)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -523,14 +523,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         num_kv_heads = self.kv_cache_spec.num_kv_heads
         head_dim = self.kv_cache_spec.head_size
 
+        # Check if any layer uses sinks (requires TRTLLM attention)
+        has_sinks = self.global_hyperparameters.has_sinks
+
         # currently prefill trtllm attention does not support fp8 kv cache
         prefill_use_trtllm = not cache_dtype.startswith("fp8") \
             and use_trtllm_attention(
                 num_prefill_tokens, max_seq_len, cache_dtype,
-                num_qo_heads, num_kv_heads, head_dim)
+                num_qo_heads, num_kv_heads, head_dim, has_sinks)
         decode_use_trtllm = use_trtllm_attention(
             num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim)
+            num_qo_heads, num_kv_heads, head_dim, has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
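For orientation, a minimal sketch of the gating this hunk relies on is shown below. It only illustrates the idea that a sinks flag forces the TRTLLM path; the real use_trtllm_attention in vLLM applies further checks (GPU architecture, token counts, head layout, environment overrides), and the default value and exact signature here are assumptions.

    # Hypothetical, simplified gate; not the actual vLLM implementation.
    def use_trtllm_attention(num_tokens: int, max_seq_len: int, cache_dtype: str,
                             num_qo_heads: int, num_kv_heads: int, head_dim: int,
                             has_sinks: bool = False) -> bool:
        if has_sinks:
            # Attention sinks (used by gpt-oss) are only supported by the
            # TRTLLM attention kernels, so force that path.
            return True
        # ... remaining heuristics decide the non-sink case ...
        return False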
@@ -642,9 +645,9 @@ class FlashInferImpl(AttentionImpl):
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
+            # Cast sinks to float32 if needed (FlashInfer requirement)
             if sinks.dtype != torch.float32:
-                raise ValueError("Sinks must be of type float32, but got "
-                                 f"{sinks.dtype}.")
+                sinks = sinks.to(torch.float32)
             self.sinks = sinks
 
     def forward(
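A rough illustration of the new behavior (the head count and input dtype are example values, not taken from the diff): sinks arrive as one logit per attention head, typically in the model dtype, and are now cast to float32 before being stored instead of triggering an error.

    import torch

    num_heads = 64                                          # example value only
    sinks = torch.zeros(num_heads, dtype=torch.bfloat16)    # model-dtype sinks

    # Mirrors the updated __init__ logic: validate the shape, then cast
    # rather than raising on non-float32 input.
    assert sinks.shape[0] == num_heads
    if sinks.dtype != torch.float32:
        sinks = sinks.to(torch.float32)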
@@ -285,6 +285,7 @@ class PerLayerParameters:
     window_left: int
     logits_soft_cap: Optional[float]
     sm_scale: float
+    has_sinks: bool = False
 
 
 def get_per_layer_parameters(
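Read on its own, the updated dataclass looks roughly like the sketch below; the field list is limited to what the diff shows, and the real class may carry more.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class PerLayerParameters:
        window_left: int
        logits_soft_cap: Optional[float]
        sm_scale: float
        has_sinks: bool = False  # new flag: does this layer use attention sinks?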
@@ -307,9 +308,11 @@ def get_per_layer_parameters(
         window_left = window_size[0] if window_size is not None else -1
         logits_soft_cap = getattr(impl, "logits_soft_cap", None)
         sm_scale = impl.scale
+        has_sinks = getattr(impl, "sinks", None) is not None
 
         per_layer_params[key] = PerLayerParameters(window_left,
-                                                   logits_soft_cap, sm_scale)
+                                                   logits_soft_cap, sm_scale,
+                                                   has_sinks)
 
     return per_layer_params
 
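The metadata builder above reads a single global_hyperparameters.has_sinks flag, so the per-layer values presumably get folded into one global flag. A plausible sketch of that aggregation, with a hypothetical helper name that is not the actual vLLM function, is:

    def any_layer_has_sinks(per_layer_params: dict[str, "PerLayerParameters"]) -> bool:
        # Hypothetical aggregation: force TRTLLM attention if any layer
        # declares sinks, matching the "any layer uses sinks" comment above.
        return any(params.has_sinks for params in per_layer_params.values())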