[Hardware][intel GPU] add async output process for xpu (#8897)

This commit is contained in:
Kunshang Ji
2024-10-15 02:23:33 +08:00
committed by GitHub
parent dfe43a2071
commit 4141608c6a
2 changed files with 8 additions and 4 deletions

View File

@@ -361,9 +361,9 @@ class ModelConfig:
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
-        if device_config.device_type not in ("cuda", "tpu"):
+        if device_config.device_type not in ("cuda", "tpu", "xpu"):
             logger.warning(
-                "Async output processing is only supported for CUDA or TPU. "
+                "Async output processing is only supported for CUDA, TPU, XPU. "
                 "Disabling it for other platforms.")
             self.use_async_output_proc = False
             return

View File

@@ -2,8 +2,8 @@ import dataclasses
 import time
 import weakref
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Type, TypeVar)

 import torch
 import torch.nn as nn
@@ -57,6 +57,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
     virtual_engine: Optional[int] = None
     seq_lens: Optional[List[int]] = None
     query_lens: Optional[List[int]] = None
+    async_callback: Optional[Callable] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -582,6 +583,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
         if not self.is_driver_worker:
             return []

+        if model_input.async_callback is not None:
+            model_input.async_callback()
+
         # Sample the next token.
         output: SamplerOutput = self.model.sample(
             logits=logits,