Use module-level Blackwell flag in compressor (works during torch.compile)
This commit is contained in:
@@ -15,6 +15,15 @@ from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Check at module load time if we're on Blackwell
|
||||
_IS_BLACKWELL = False
|
||||
try:
|
||||
_cap = current_platform.get_device_capability()
|
||||
if _cap is not None and _cap.major >= 10:
|
||||
_IS_BLACKWELL = True
|
||||
except Exception:
|
||||
pass
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
@@ -343,8 +352,7 @@ class DeepseekCompressor(nn.Module):
|
||||
# 2. Our Blackwell attention path handles everything separately
|
||||
# Instead, we just save the state (done above) and let the attention
|
||||
# path handle compression + RoPE + cache write + attention.
|
||||
cap = current_platform.get_device_capability()
|
||||
if cap is not None and cap.major >= 10:
|
||||
if _IS_BLACKWELL:
|
||||
# Blackwell: state is already saved, skip fused kernel
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user