[TPU] make ptxla not imported when using tpu_commons (#23081)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
Author: Chengji Yao
Date:   2025-08-18 20:46:42 -07:00 (committed by GitHub)
Parent: a4454e9401
Commit: e9d6a3db69
6 changed files with 94 additions and 78 deletions
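The change below wraps every torch_xla (ptxla) import, and the custom-op registration that depends on it, in the except branch of a try/except around importing tpu_commons, so torch_xla is only pulled in when tpu_commons is absent. A quick way to sanity-check this on a host that has tpu_commons installed is to import the backend and confirm torch_xla never shows up in sys.modules; a minimal sketch (the module path is an assumption, since the file name is not visible in this view):

import sys

# Importing the Pallas attention backend should no longer pull in torch_xla
# when tpu_commons is available (illustrative check, not part of this commit).
import vllm.v1.attention.backends.pallas  # noqa: F401

assert "torch_xla" not in sys.modules, "ptxla was imported unexpectedly"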


@@ -5,12 +5,6 @@ from dataclasses import dataclass
 from typing import Optional

 import torch
-import torch_xla.core.xla_builder as xb
-import torch_xla.experimental.custom_kernel  # noqa: F401
-# Required to register custom ops.
-from torch.library import impl
-from torch_xla._internal.jax_workarounds import requires_jax
-from torch_xla.experimental.custom_kernel import XLA_LIB
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -37,6 +31,57 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
     "uint8": torch.uint8,
 }

+try:
+    import tpu_commons  # noqa: F401
+except ImportError:
+    # Lazy import torch_xla
+    import torch_xla.core.xla_builder as xb
+    import torch_xla.experimental.custom_kernel  # noqa: F401
+    from torch.library import impl
+    from torch_xla._internal.jax_workarounds import requires_jax
+    from torch_xla.experimental.custom_kernel import XLA_LIB
+
+    @requires_jax
+    def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                                kv_cache: torch.Tensor,
+                                num_kv_update_slices: torch.Tensor,
+                                page_size: int, num_slices_per_block: int):
+        from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+        new_kv_cache = xb.call_jax(
+            kv_cache_update,
+            (kv, slot_mapping, kv_cache, num_kv_update_slices), {
+                "page_size": page_size,
+                "num_slices_per_block": num_slices_per_block
+            })
+        return new_kv_cache
+
+    XLA_LIB.define(
+        "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \
+        "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \
+        "int num_slices_per_block)" \
+        "-> Tensor", )
+
+    @impl(XLA_LIB, "kv_cache_update_op", "XLA")
+    def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor,
+                               num_kv_update_slices: torch.Tensor,
+                               page_size: int,
+                               num_slices_per_block: int) -> torch.Tensor:
+        new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                               num_kv_update_slices, page_size,
+                                               num_slices_per_block)
+        return new_kv_cache
+
+    @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+    def kv_cache_update_op_non_xla(kv: torch.Tensor,
+                                   slot_mapping: torch.Tensor,
+                                   kv_cache: torch.Tensor,
+                                   num_kv_update_slices: torch.Tensor,
+                                   page_size: int,
+                                   num_slices_per_block: int) -> torch.Tensor:
+        return kv_cache
+
+
 class PallasAttentionBackend(AttentionBackend):
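For context, once the op is defined on XLA_LIB inside the except branch, call sites elsewhere in the backend can invoke it through the torch.ops dispatcher without importing torch_xla themselves. A minimal sketch of such a call, assuming XLA_LIB registers into the "xla" namespace (as torch_xla's custom_kernel module does) and with purely illustrative tensor shapes:

import torch

# Illustrative shapes only: 16 new tokens, 2 KV heads, head size 128,
# and a paged cache of 64 pages * page_size 32 slots.
kv = torch.zeros(16, 2, 128)
kv_cache = torch.zeros(64 * 32, 2, 128)
slot_mapping = torch.zeros(3, 16, dtype=torch.int32)
num_kv_update_slices = torch.tensor([16], dtype=torch.int32)

# On XLA tensors this dispatches to kv_cache_update_op_xla; on any other
# backend it hits the CompositeExplicitAutograd fallback, which simply
# returns kv_cache unchanged so shape propagation still works.
new_kv_cache = torch.ops.xla.kv_cache_update_op(kv, slot_mapping, kv_cache,
                                                num_kv_update_slices,
                                                32,  # page_size
                                                8)   # num_slices_per_block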
@@ -313,46 +358,6 @@ def write_to_kv_cache(
     kv_cache.copy_(new_kv_cache)


-@requires_jax
-def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                            kv_cache: torch.Tensor,
-                            num_kv_update_slices: torch.Tensor, page_size: int,
-                            num_slices_per_block: int):
-    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
-    new_kv_cache = xb.call_jax(
-        kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
-            "page_size": page_size,
-            "num_slices_per_block": num_slices_per_block
-        })
-    return new_kv_cache
-
-
-XLA_LIB.define(
-    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
-    "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
-    "-> Tensor", )
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "XLA")
-def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                           kv_cache: torch.Tensor,
-                           num_kv_update_slices: torch.Tensor, page_size: int,
-                           num_slices_per_block: int) -> torch.Tensor:
-    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
-                                           num_kv_update_slices, page_size,
-                                           num_slices_per_block)
-    return new_kv_cache
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
-def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                               kv_cache: torch.Tensor,
-                               num_kv_update_slices: torch.Tensor,
-                               page_size: int,
-                               num_slices_per_block: int) -> torch.Tensor:
-    return kv_cache
-
-
 # We can move this function to a common utils file if it's also useful for other
 # hardware.
 def dtype_bits(dtype: torch.dtype):
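The dtype_bits helper appears here only as unchanged context and its body is cut off by the hunk. A minimal sketch of what such a helper typically looks like (an assumption for illustration, not the code from this file), reading the bit width from torch.finfo/torch.iinfo:

import torch

def dtype_bits_sketch(dtype: torch.dtype) -> int:
    # Floating-point dtypes expose their width via torch.finfo,
    # integer dtypes via torch.iinfo.
    if dtype.is_floating_point:
        return torch.finfo(dtype).bits
    return torch.iinfo(dtype).bits

assert dtype_bits_sketch(torch.bfloat16) == 16
assert dtype_bits_sketch(torch.uint8) == 8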