[Bugfix][TPU] Fix tpu model runner testcase failure (#18810)
Signed-off-by: Carol Zheng <cazheng@google.com>
@@ -175,11 +175,21 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.kv_caches: list[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

         # self.input_batch: InputBatch # Persistent batch.

         # Request states.
         self.requests: dict[str, CachedRequestState] = {}

+        # Initialize input batch early to avoid AttributeError in _update_states
+        self.input_batch = InputBatch(
+            max_num_reqs=self.max_num_reqs,
+            max_model_len=self.max_model_len,
+            max_num_batched_tokens=self.max_num_tokens,
+            device=self.device,
+            pin_memory=self.pin_memory,
+            vocab_size=self.model_config.get_vocab_size(),
+            block_size=self.block_size,
+        )
+
         # Cached torch/numpy tensor
         # The pytorch tensor and numpy array share the same buffer.
         # Sometimes the numpy op is faster so we create both.
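Why the early initialization matters: _update_states dereferences self.input_batch, so a runner whose batch is only created inside initialize_kv_cache raises AttributeError when _update_states runs first, which is what the failing testcase hit. A minimal sketch of the pattern, using simplified stand-in classes rather than the real vLLM InputBatch/TPUModelRunner:

# Minimal sketch of the failure mode and the fix. SimpleBatch and Runner
# are illustrative stand-ins, not the real vLLM classes.
class SimpleBatch:
    def __init__(self, block_size: int):
        self.block_size = block_size


class Runner:
    def __init__(self, block_size: int = 16):
        self.block_size = block_size
        # The fix: construct the batch eagerly in __init__ so that
        # _update_states never dereferences a missing attribute.
        self.input_batch = SimpleBatch(block_size)

    def initialize_kv_cache(self, kv_block_size: int) -> None:
        # Re-create the batch only when the KV cache spec disagrees
        # with the default block size (mirrors the second hunk below).
        if kv_block_size != self.block_size:
            self.input_batch = SimpleBatch(kv_block_size)

    def _update_states(self) -> int:
        # Before the fix, calling this ahead of initialize_kv_cache raised
        # AttributeError: 'Runner' object has no attribute 'input_batch'.
        return self.input_batch.block_size


runner = Runner()
assert runner._update_states() == 16   # safe before KV cache init
runner.initialize_kv_cache(kv_block_size=32)
assert runner._update_states() == 32   # rebuilt with the KV cache spec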
@@ -1286,16 +1296,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 "Hybrid models with more than one KV cache type are not "
                 "supported yet.")

-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
-            max_num_batched_tokens=self.max_num_tokens,
-            device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
-            block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
-            block_size,
-        )
+        if kv_cache_config.kv_cache_groups[
+                0].kv_cache_spec.block_size != self.block_size:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
+                block_size,
+            )
+        # Verify dtype compatibility between block_table_cpu and input_batch
+        assert self.block_table_cpu.dtype == self.input_batch.block_table[
+            0].get_cpu_tensor().dtype
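The new assert guards that the staging tensor self.block_table_cpu and the batch's block table agree on dtype; mismatched integer widths would corrupt block indices when one is copied into the other. A short illustrative check of what it catches (the tensor shapes and names here are made up, not the real vLLM internals):

import torch

# Illustrative only: a CPU staging buffer and a batch-owned block table
# that must share a dtype so copies between them are lossless.
block_table_cpu = torch.zeros((8, 32), dtype=torch.int32)
batch_block_table_cpu = torch.zeros((8, 32), dtype=torch.int32)

# Mirrors the assert in the diff: fail fast at initialization instead of
# debugging silently widened/narrowed block indices at execution time.
assert block_table_cpu.dtype == batch_block_table_cpu.dtype

# An int64 table paired with this int32 buffer would trip the assert:
# torch.zeros((8, 32), dtype=torch.int64).dtype != block_table_cpu.dtype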