[V1] TPU - Add tensor parallel support via Ray (#13618)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
commit cb8bdfade2
parent 33f227e16b
@@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
@@ -545,6 +546,7 @@ class TPUModelRunner:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         # Update cached state
         self._update_states(scheduler_output)
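The new intermediate_tensors parameter appears to align TPUModelRunner.execute_model with the model-runner interface vLLM uses for distributed execution, where a worker can receive tensors produced by another rank; the commit's headline change is driving multiple TPU workers as Ray actors, one per tensor-parallel rank. Below is a minimal, self-contained sketch of that fan-out/fan-in pattern. ShardWorker, its fields, and the dict payloads are illustrative stand-ins, not vLLM's actual classes or API.

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class ShardWorker:
    # Stands in for one tensor-parallel rank; a real worker would hold
    # one shard of the model weights and run the sharded forward pass.
    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.world_size = world_size

    def execute_model(self, scheduler_output: dict) -> dict:
        # A real implementation would all-reduce partial results across
        # ranks; here we only echo enough to show the structure.
        return {"rank": self.rank, "num_tokens": scheduler_output["num_tokens"]}

world_size = 2
workers = [ShardWorker.remote(r, world_size) for r in range(world_size)]
# The driver broadcasts the same scheduler output to every rank and
# gathers one result per worker.
outputs = ray.get([w.execute_model.remote({"num_tokens": 8}) for w in workers])
print(outputs)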
@@ -96,7 +96,8 @@ class TPUWorker:
 
         # Set random seed.
         set_random_seed(self.model_config.seed)
-        xm.set_rng_state(self.model_config.seed, self.device)
+        if self.model_config.seed is not None:
+            xm.set_rng_state(self.model_config.seed, self.device)
 
         # Increase the cache size limit, which is the maximum number of
         # dynamo graphs that can be compiled.
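The TPUWorker change guards the XLA RNG call because model_config.seed may be None, and torch_xla's xm.set_rng_state expects an integer seed. A standalone sketch of the same guard, where seed_device_rng is a hypothetical helper name:

from typing import Optional

def seed_device_rng(seed: Optional[int]) -> None:
    # Hypothetical helper mirroring the guarded call in the diff.
    if seed is None:
        return  # no seed configured; leave the device RNG state untouched
    import torch_xla.core.xla_model as xm  # requires torch_xla on an XLA host
    xm.set_rng_state(seed, xm.xla_device())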