[V1][core] Implement pipeline parallel on Ray (#12996)
@@ -12,7 +12,7 @@ import torch.nn as nn
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
-from vllm.distributed.parallel_state import graph_capture
+from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
@@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
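Note: get_pp_group() (imported above) returns the GroupCoordinator for this worker's pipeline stage; its is_first_rank/is_last_rank flags drive every branch added in this commit. A minimal illustrative sketch, not part of the diff:

from vllm.distributed.parallel_state import get_pp_group

def pp_stage_role() -> str:
    # Requires vLLM's distributed state to be initialized first.
    pp_group = get_pp_group()
    if pp_group.is_first_rank:
        return "first"   # consumes input_ids / input embeddings
    if pp_group.is_last_rank:
        return "last"    # computes logits and samples tokens
    return "middle"      # relays IntermediateTensors between stages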
@@ -773,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         batch_changed = self._update_states(scheduler_output)

@@ -831,8 +833,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             positions=positions,
             kv_caches=self.kv_caches,
             attn_metadata=None,
+            intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
+        if not get_pp_group().is_last_rank:
+            return hidden_states
         hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
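Note: with the early return above, execute_model on a non-last rank hands back raw hidden states (an IntermediateTensors) instead of a ModelRunnerOutput; the Ray executor is responsible for shipping them to the next stage. A rough sketch of that per-step flow; the driver loop itself is hypothetical, while send_tensor_dict/recv_tensor_dict are the GroupCoordinator's tensor-dict transport:

from vllm.distributed.parallel_state import get_pp_group
from vllm.sequence import IntermediateTensors

def run_one_step(runner, scheduler_output):
    pp_group = get_pp_group()
    intermediate_tensors = None
    if not pp_group.is_first_rank:
        # Non-first stages block until the previous stage sends activations.
        intermediate_tensors = IntermediateTensors(
            tensors=pp_group.recv_tensor_dict())
    output = runner.execute_model(scheduler_output, intermediate_tensors)
    if not pp_group.is_last_rank:
        # Forward the activations; only the last rank samples tokens.
        pp_group.send_tensor_dict(output.tensors)
        return None
    return output  # ModelRunnerOutput from the last pipeline stage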
@@ -1007,12 +1012,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions[:num_tokens]
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=num_tokens,
+                dtype=self.model_config.dtype,
+                device=self.device)
         with set_forward_context(None, self.vllm_config):
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
                 kv_caches=kv_caches,
                 attn_metadata=None,
+                intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
         return hidden_states
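Note: non-first ranks have no real upstream activations during a dummy run, so the hunk above fabricates correctly shaped placeholders before profiling and CUDA graph capture. A sketch of what a decoder-only model's make_empty_intermediate_tensors conventionally builds; hidden_size normally comes from the model config and is an explicit parameter here only for illustration:

import torch
from vllm.sequence import IntermediateTensors

def make_empty_intermediate_tensors(batch_size, hidden_size, dtype, device):
    # Zero tensors with the same keys/shapes the previous stage would send,
    # so compilation and graph capture see realistic inputs.
    return IntermediateTensors({
        "hidden_states": torch.zeros((batch_size, hidden_size),
                                     dtype=dtype, device=device),
        "residual": torch.zeros((batch_size, hidden_size),
                                dtype=dtype, device=device),
    })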
@@ -1142,6 +1154,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Trigger compilation for general shape.
             hidden_states = self._dummy_run(self.max_num_tokens,
                                             dummy_kv_caches)
+            if not get_pp_group().is_last_rank:
+                return hidden_states
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
             # TODO(woosuk): Consider the memory usage of the sampler.

@@ -2,7 +2,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch
 import torch.distributed
@@ -194,8 +194,9 @@ class Worker:
     def get_kv_cache_spec(self) -> KVCacheSpec:
         return self.model_runner.get_kv_cache_spec()

-    def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
+        kv_cache_config = kv_cache_configs[self.rank]
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
             context = allocator.use_memory_pool(tag="kv_cache")
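Note: initialize_cache now takes one KVCacheConfig per worker and indexes it by rank, since pipeline stages hold different layer subsets and can end up with different cache layouts. A sketch of the executor-side calling convention this implies; workers and compute_kv_cache_config are hypothetical stand-ins for the engine-core plumbing:

# Gather one spec per worker, plan one config per worker, broadcast the list.
kv_cache_specs = [w.get_kv_cache_spec() for w in workers]
kv_cache_configs = [compute_kv_cache_config(spec) for spec in kv_cache_specs]
for w in workers:
    # Every worker receives the full list and picks kv_cache_configs[w.rank].
    w.initialize_cache(kv_cache_configs)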