Use module-level Blackwell flag in compressor (works during torch.compile)

This commit is contained in:
2026-05-19 17:37:26 +00:00
parent 8cf6ac3e8c
commit 4f02113aa0

View File

@@ -15,6 +15,15 @@ from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
)
from vllm.platforms import current_platform
# Check at module load time if we're on Blackwell
_IS_BLACKWELL = False
try:
_cap = current_platform.get_device_capability()
if _cap is not None and _cap.major >= 10:
_IS_BLACKWELL = True
except Exception:
pass
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -343,8 +352,7 @@ class DeepseekCompressor(nn.Module):
# 2. Our Blackwell attention path handles everything separately
# Instead, we just save the state (done above) and let the attention
# path handle compression + RoPE + cache write + attention.
cap = current_platform.get_device_capability()
if cap is not None and cap.major >= 10:
if _IS_BLACKWELL:
# Blackwell: state is already saved, skip fused kernel
return