[CPU] update torch 2.8 and fix missing fields in TorchSDPAMetadata (#25652)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-09-25 21:29:11 +08:00
committed by GitHub
parent 69a8c8e99a
commit eb32335e35
7 changed files with 59 additions and 53 deletions

View File

@@ -85,6 +85,19 @@ class TorchSDPABackend(AttentionBackend):
@dataclass
class TorchSDPAMetadata(AttentionMetadata):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
@@ -420,7 +433,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
num_prompt_req], # prefill
query_start_loc=query_start_loc_cpu[:num_reqs +
1], # for logits index
enable_kv_scales_calculation=False,
)
return attn_metadata

View File

@@ -68,6 +68,8 @@ class TopKTopPSampler(nn.Module):
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_cpu():
self.forward = self.forward_cpu
else:
self.forward = self.forward_native
@@ -119,6 +121,45 @@ class TopKTopPSampler(nn.Module):
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators), None
def forward_cpu(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
PyTorch-native implementation of top-k and top-p sampling for CPU.
The logits tensor may be updated in-place.
"""
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)
if len(generators) != logits.shape[0]:
return compiled_random_sample(logits), logits_to_return
else:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
def apply_top_k_top_p(
logits: torch.Tensor,

View File

@@ -8,18 +8,13 @@ import torch
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
init_worker_distributed_environment)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
logger = init_logger(__name__)
@@ -102,40 +97,6 @@ class CPUWorker(Worker):
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = self.model_runner._get_num_input_tokens(
num_scheduled_tokens)
all_gather_tensors = {
"residual":
not is_residual_scattered_for_sp(self.vllm_config,
num_input_tokens)
}
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(
output.tensors,
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors)
return None
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None
def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]],
list[LogicalCPUInfo]]