Fix imports in vLLM codepaths test
This commit is contained in:
@@ -98,11 +98,10 @@ def causal_prefill_attention(q, kv, scale):
|
||||
def main():
|
||||
"""Test the exact csa_attention.py code paths used in the container."""
|
||||
from cutedsl.blackwell_attention import (
|
||||
blackwell_attention_kv_write,
|
||||
blackwell_attention_decode,
|
||||
blackwell_attention_forward,
|
||||
apply_gptj_rope,
|
||||
apply_inv_gptj_rope,
|
||||
)
|
||||
# Also import the vLLM patch version
|
||||
# Import the vLLM patch version (the actual code used in the container)
|
||||
sys.path.insert(0, os.path.join(REPO, "vllm", "patches", "layers"))
|
||||
from csa_attention import (
|
||||
fused_qnorm_rope_kv_insert_py,
|
||||
@@ -110,6 +109,7 @@ def main():
|
||||
blackwell_attention_decode as vllm_decode,
|
||||
kv_quantize_fp8 as vllm_kv_quantize,
|
||||
kv_dequantize_fp8 as vllm_kv_dequantize,
|
||||
causal_prefill_attention,
|
||||
)
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
Reference in New Issue
Block a user