Fix imports in vLLM codepaths test

This commit is contained in:
2026-05-19 17:26:50 +00:00
parent 835e1a0590
commit facc6509e7

View File

@@ -98,11 +98,10 @@ def causal_prefill_attention(q, kv, scale):
def main():
"""Test the exact csa_attention.py code paths used in the container."""
from cutedsl.blackwell_attention import (
blackwell_attention_kv_write,
blackwell_attention_decode,
blackwell_attention_forward,
apply_gptj_rope,
apply_inv_gptj_rope,
)
# Also import the vLLM patch version
# Import the vLLM patch version (the actual code used in the container)
sys.path.insert(0, os.path.join(REPO, "vllm", "patches", "layers"))
from csa_attention import (
fused_qnorm_rope_kv_insert_py,
@@ -110,6 +109,7 @@ def main():
blackwell_attention_decode as vllm_decode,
kv_quantize_fp8 as vllm_kv_quantize,
kv_dequantize_fp8 as vllm_kv_dequantize,
causal_prefill_attention,
)
torch.cuda.set_device(0)