[Kernel][TPU][ragged-paged-attn] vLLM code change for PR#8896 (#15659)
Signed-off-by: Yarong Mu <ymu@google.com>
This commit is contained in:
@@ -861,12 +861,11 @@ class TPUModelRunner:
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
|
||||
tpu_k_cache = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||
tpu_kv_cache = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
|
||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||
kv_caches[layer_name] = tpu_kv_cache
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -893,7 +892,7 @@ class ModelWrapperV1(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_caches: list[torch.Tensor],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model.
|
||||
|
||||
@@ -136,10 +136,10 @@ class TPUWorker:
|
||||
|
||||
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||
|
||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||
tpu_kv_cache = torch.tensor([],
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = tpu_kv_cache
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user