[Hardware][Intel GPU] Add async output processing for XPU (#8897)
This commit is contained in:
@@ -361,9 +361,9 @@ class ModelConfig:
|
|||||||
|
|
||||||
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
||||||
# If the feature combo becomes valid
|
# If the feature combo becomes valid
|
||||||
if device_config.device_type not in ("cuda", "tpu"):
|
if device_config.device_type not in ("cuda", "tpu", "xpu"):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Async output processing is only supported for CUDA or TPU. "
|
"Async output processing is only supported for CUDA, TPU, XPU. "
|
||||||
"Disabling it for other platforms.")
|
"Disabling it for other platforms.")
|
||||||
self.use_async_output_proc = False
|
self.use_async_output_proc = False
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import dataclasses
|
|||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
|
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||||
TypeVar)
|
Type, TypeVar)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -57,6 +57,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
|
|||||||
virtual_engine: Optional[int] = None
|
virtual_engine: Optional[int] = None
|
||||||
seq_lens: Optional[List[int]] = None
|
seq_lens: Optional[List[int]] = None
|
||||||
query_lens: Optional[List[int]] = None
|
query_lens: Optional[List[int]] = None
|
||||||
|
async_callback: Optional[Callable] = None
|
||||||
|
|
||||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
@@ -582,6 +583,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
|||||||
if not self.is_driver_worker:
|
if not self.is_driver_worker:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
if model_input.async_callback is not None:
|
||||||
|
model_input.async_callback()
|
||||||
|
|
||||||
# Sample the next token.
|
# Sample the next token.
|
||||||
output: SamplerOutput = self.model.sample(
|
output: SamplerOutput = self.model.sample(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
|
|||||||
Reference in New Issue
Block a user