[CPU] update torch 2.8 and fix missing fields in TorchSDPAMetadata (#25652)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -85,6 +85,19 @@ class TorchSDPABackend(AttentionBackend):
|
||||
|
||||
@dataclass
|
||||
class TorchSDPAMetadata(AttentionMetadata):
|
||||
"""Attention metadata for prefill and decode batched together."""
|
||||
# Total number of prefill requests.
|
||||
num_prefills: int
|
||||
# Number of prefill tokens.
|
||||
num_prefill_tokens: int
|
||||
# Number of decode tokens. Note that it is equivalent to the number of
|
||||
# decode requests.
|
||||
num_decode_tokens: int
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor
|
||||
"""Metadata for PagedAttention."""
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
@@ -420,7 +433,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||
num_prompt_req], # prefill
|
||||
query_start_loc=query_start_loc_cpu[:num_reqs +
|
||||
1], # for logits index
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
@@ -68,6 +68,8 @@ class TopKTopPSampler(nn.Module):
|
||||
"native implementation of top-p & top-k sampling. For the "
|
||||
"best performance, please install FlashInfer.")
|
||||
self.forward = self.forward_native
|
||||
elif current_platform.is_cpu():
|
||||
self.forward = self.forward_cpu
|
||||
else:
|
||||
self.forward = self.forward_native
|
||||
|
||||
@@ -119,6 +121,45 @@ class TopKTopPSampler(nn.Module):
|
||||
# because of slicing operation in logits_processor.
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators), None
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
PyTorch-native implementation of top-k and top-p sampling for CPU.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = self.apply_top_k_top_p(logits, k, p)
|
||||
logits_to_return = None
|
||||
if self.logprobs_mode == "processed_logits":
|
||||
logits_to_return = logits
|
||||
elif self.logprobs_mode == "processed_logprobs":
|
||||
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
# Note: this is a workaround for
|
||||
# https://github.com/pytorch/pytorch/pull/151218
|
||||
@torch.compile(dynamic=True)
|
||||
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
return probs.div(q).argmax(dim=-1).view(-1)
|
||||
|
||||
if len(generators) != logits.shape[0]:
|
||||
return compiled_random_sample(logits), logits_to_return
|
||||
else:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
|
||||
@@ -8,18 +8,13 @@ import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.v1.worker.gpu_worker import (Worker,
|
||||
init_worker_distributed_environment)
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -102,40 +97,6 @@ class CPUWorker(Worker):
|
||||
set_random_seed(self.model_config.seed)
|
||||
self.model_runner.warming_up_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
intermediate_tensors = None
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
num_input_tokens = self.model_runner._get_num_input_tokens(
|
||||
num_scheduled_tokens)
|
||||
all_gather_tensors = {
|
||||
"residual":
|
||||
not is_residual_scattered_for_sp(self.vllm_config,
|
||||
num_input_tokens)
|
||||
}
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors))
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output,
|
||||
intermediate_tensors)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(
|
||||
output.tensors,
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors)
|
||||
return None
|
||||
|
||||
assert isinstance(output, ModelRunnerOutput)
|
||||
return output if self.is_driver_worker else None
|
||||
|
||||
def _get_autobind_cpu_ids(
|
||||
self, cpu_selector: Callable[[list[LogicalCPUInfo]],
|
||||
list[LogicalCPUInfo]]
|
||||
|
||||
Reference in New Issue
Block a user