[TPU][V1][Bugfix] Fix w8a8 recompiilation with GSM8K (#15714)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -80,6 +80,7 @@ class TPUModelRunner:
|
||||
self.enforce_eager = model_config.enforce_eager
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
self._hidden_states_dtype = self.dtype
|
||||
|
||||
self.is_multimodal_model = model_config.is_multimodal_model
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
@@ -771,10 +772,11 @@ class TPUModelRunner:
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
self.model(input_ids=input_ids,
|
||||
positions=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds)
|
||||
out = self.model(input_ids=input_ids,
|
||||
positions=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds)
|
||||
self._hidden_states_dtype = out.dtype
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""Compile the model."""
|
||||
@@ -800,7 +802,7 @@ class TPUModelRunner:
|
||||
num_reqs_to_sample = MIN_NUM_SEQS
|
||||
dummy_hidden = torch.randn((num_tokens, hsize),
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
dtype=self._hidden_states_dtype)
|
||||
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
|
||||
while True:
|
||||
indices = torch.zeros(
|
||||
@@ -823,7 +825,7 @@ class TPUModelRunner:
|
||||
num_reqs_to_sample + 1, self.max_num_reqs)
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
logger.info("Compilation finished in %.2f [secs].", end - start)
|
||||
# Record the number cached XLA graph after warming up, this will be
|
||||
# used for checking there is no additional graph compilation during
|
||||
# runtime execution.
|
||||
|
||||
@@ -105,8 +105,8 @@ class TPUWorker:
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
|
||||
# 30-40 graphs for decode. 128 is an arbitrary safe number.
|
||||
# TODO (NickLucche) On gsm we compile 80+ graphs.
|
||||
# Re-evaluate limit, with MM we may get close to this limit.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||
|
||||
Reference in New Issue
Block a user