Compare commits


8 Commits

Author SHA1 Message Date
Robert Shaw
2339d59f92 [BugFix] Fix quantization for all other methods (#11547)
2024-12-26 22:23:29 -08:00
Robert Shaw
1b875a0ef3 [V1][3/N] API Server: Reduce Task Switching + Handle Abort Properly (#11534) 2024-12-26 21:19:21 -08:00
youkaichao
eb881ed006 [misc] fix typing (#11540)
Signed-off-by: youkaichao <youkaichao@gmail.com>
2024-12-27 11:05:08 +08:00
Robert Shaw
46d4359450 [CI] Fix broken CI (#11543) 2024-12-26 18:49:16 -08:00
Woosuk Kwon
81b979f2a8 [V1] Fix yapf (#11538)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-27 09:47:10 +09:00
Woosuk Kwon
371d04d39b [V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-27 09:32:38 +09:00
Robert Shaw
0c0c2015c5 Update openai_compatible_server.md (#11536)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-12-26 16:26:18 -08:00
Simon Mo
82d24f7aac [Docs] Document Deepseek V3 support (#11535)
Signed-off-by: simon-mo <simon.mo@hey.com>
2024-12-26 16:21:56 -08:00
20 changed files with 491 additions and 370 deletions

View File

@@ -60,7 +60,7 @@ vLLM is flexible and easy to use with:
 vLLM seamlessly supports most popular open-source models on HuggingFace, including:
 - Transformer-like LLMs (e.g., Llama)
-- Mixture-of-Expert LLMs (e.g., Mixtral)
+- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
 - Embedding Models (e.g. E5-Mistral)
 - Multi-modal LLMs (e.g., LLaVA)

View File

@@ -137,6 +137,11 @@ See [this page](#generative-models) for more information on how to use generative models.
   - :code:`deepseek-ai/DeepSeek-V2`, :code:`deepseek-ai/DeepSeek-V2-Chat` etc.
   -
   - ✅︎
+* - :code:`DeepseekV3ForCausalLM`
+  - DeepSeek-V3
+  - :code:`deepseek-ai/DeepSeek-V3-Base`, :code:`deepseek-ai/DeepSeek-V3` etc.
+  -
+  - ✅︎
 * - :code:`ExaoneForCausalLM`
   - EXAONE-3
   - :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
@@ -676,7 +681,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   - PaliGemma, PaliGemma 2
   - T + I\ :sup:`E`
   - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, :code:`google/paligemma2-3b-ft-docci-448`, etc.
   -
   - ✅︎
   -
 * - :code:`Phi3VForCausalLM`
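With the DeepSeek-V3 rows above in place, loading the model follows the usual offline-inference path. A minimal sketch (the parallelism value is illustrative only; DeepSeek-V3 is a very large MoE model and needs multiple GPUs, and `trust_remote_code=True` mirrors the test registry entry later in this diff):

```python
from vllm import LLM, SamplingParams

# tensor_parallel_size=8 is an assumption for illustration, not a requirement.
llm = LLM(model="deepseek-ai/DeepSeek-V3",
          trust_remote_code=True,
          tensor_parallel_size=8)

outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.8))
print(outputs[0].outputs[0].text)
```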

View File

@@ -112,7 +112,13 @@ completion = client.chat.completions.create(
 ## Extra HTTP Headers
 
-Only `X-Request-Id` HTTP request header is supported for now.
+Only `X-Request-Id` HTTP request header is supported for now. It can be enabled
+with `--enable-request-id-headers`.
+
+> Note that enablement of the headers can impact performance significantly at high QPS
+> rates. We recommend implementing HTTP headers at the router level (e.g. via Istio),
+> rather than within the vLLM layer for this reason.
+> See https://github.com/vllm-project/vllm/pull/11529 for more details.
 
 ```python
 completion = client.chat.completions.create(
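For reference alongside the docs change above, one way a client could attach the header once the server is started with `--enable-request-id-headers` (a sketch using the standard `openai` Python client; the model name is a placeholder):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
    # extra_headers is forwarded verbatim by the OpenAI Python client.
    extra_headers={"X-Request-Id": "my-trace-id-123"},
)
```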

View File

@@ -61,6 +61,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
     "DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat",  # noqa: E501
                                              trust_remote_code=True),
+    "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3",  # noqa: E501
+                                             trust_remote_code=True),
     "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),  # noqa: E501
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),

View File

@@ -68,7 +68,7 @@ def _create_default_sampling_metadata(
         no_top_p=True,
         no_top_k=True,
         generators={},
-        max_num_logprobs=VOCAB_SIZE,
+        max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
@@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     sampling_metadata.min_tokens = min_tokens
     sampling_metadata.stop_token_ids = stop_token_ids
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        for vocab in range(VOCAB_SIZE):
-            # Verify that the logprobs for stop token ids is set
-            # to -inf.
-            logprob_index = torch.where(
-                sampler_output.logprob_token_ids[batch_idx] ==
-                vocab)[0].item()
-            if vocab in stop_token_ids[batch_idx]:
-                assert sampler_output.logprobs[batch_idx][
-                    logprob_index] == -float("inf")
+        for token_id in range(VOCAB_SIZE):
+            if token_id in stop_token_ids[batch_idx]:
+                assert logits[batch_idx][token_id] == -float("inf")
             else:
-                assert sampler_output.logprobs[batch_idx][
-                    logprob_index] != -float("inf")
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+                assert logits[batch_idx][token_id] != -float("inf")
@@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
         batch_size, presence_penalty, torch.device(device))
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        # The logprobs in the SamplerOutput are arranged in descending order.
-        # Since all tokens initially have the same logprobs, the non-penalized
-        # tokens will appear at the beginning, while the penalized tokens
-        # will appear at the end of the list.
-        penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
-            VOCAB_SIZE - 1]
-        penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
-        non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
-        non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
-        assert non_penalized_log_prod > penalized_log_prod
+        # Since all tokens initially have the same logits, the non-penalized
+        # token ID will be the one with the highest logit value, while the
+        # penalized token ID will be the one with the lowest logit value.
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         if presence_penalty > 0:
             # If `presence_penalty` is set to a value greater than 0, it
             # indicates a preference for new tokens over those already
@@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
-        non_penalized_token_id = logprobs_token_ids[0]
-        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         distinct_sorted_token_ids_in_output = \
             sorted_token_ids_in_output[batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
@@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
         batch_size, repetition_penalty, torch.device(device))
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
-        non_penalized_token_id = logprobs_token_ids[0]
-        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         prompt_tokens = sampling_metadata.prompt_token_ids[
             batch_idx][:].tolist()
         output_tokens = sampling_metadata.output_token_ids[batch_idx]

View File

@@ -208,8 +208,8 @@ def wrap_inductor(graph: fx.GraphModule,
     from torch._inductor.compile_fx import graph_returns_tuple
     returns_tuple = graph_returns_tuple(graph)
 
-    # this is the graph we return to Dynamo to run
-    def compiled_graph(*args) -> Optional[fx.CompiledFxGraph]:
+    # this is the callable we return to Dynamo to run
+    def compiled_graph(*args):
         # convert args to list
         list_args = list(args)
         graph_output = inductor_compiled_graph(list_args)
@@ -537,7 +537,8 @@ class VllmBackend:
             example_inputs[x].clone() for x in self.sym_tensor_indices
         ]
 
-        def copy_and_call(*args) -> fx.GraphModule:
+        # this is the callable we return to Dynamo to run
+        def copy_and_call(*args):
             list_args = list(args)
             for i, index in enumerate(self.sym_tensor_indices):
                 runtime_tensor = list_args[index]
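The second hunk names the pattern explicitly: `copy_and_call` copies the runtime tensors at the symbolic-shape positions into static buffers the compiled graph captured, then invokes the graph. A toy, self-contained version of that idea (all names and the slicing convention are illustrative, not vLLM's actual objects):

```python
import torch

# Stand-ins for the cloned example_inputs captured by a compiled graph.
static_tensors = [torch.zeros(8, 16), torch.zeros(8, 16)]
sym_tensor_indices = [0, 1]

def compiled_graph(*args):
    # Stand-in for the inductor-compiled callable.
    return args[0] + args[1]

def copy_and_call(*args):
    list_args = list(args)
    for i, index in enumerate(sym_tensor_indices):
        runtime_tensor = list_args[index]
        # Copy the runtime data into the captured buffer, then hand the
        # buffer (not the runtime tensor) to the compiled graph.
        static_tensors[i][:runtime_tensor.shape[0]].copy_(runtime_tensor)
        list_args[index] = static_tensors[i]
    return compiled_graph(*list_args)

out = copy_and_call(torch.ones(4, 16), torch.ones(4, 16))
```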

View File

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
-    VLLM_USE_FLASHINFER_SAMPLER: bool = False
+    VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
     VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
     VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
@@ -277,7 +277,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER":
-    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
+    lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
+    if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,
 
     # If set, vllm will force flashinfer to use tensor cores;
     # otherwise will use heuristic based on model architecture.
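The flag is now tri-state: unset reads as `None`, while an explicit `0`/`1` reads as `False`/`True`, which is what lets the V0 and V1 samplers pick different defaults (see the sampler comment later in this diff). A standalone sketch of the same pattern, with a hypothetical variable name:

```python
import os
from typing import Optional

def read_tristate_flag(name: str) -> Optional[bool]:
    """Unset -> None; "0"/"1" -> False/True (mirrors the new env lambda)."""
    if name not in os.environ:
        return None
    return bool(int(os.environ[name]))

os.environ.pop("MY_FLAG", None)
assert read_tristate_flag("MY_FLAG") is None   # unset: defer to the caller
os.environ["MY_FLAG"] = "0"
assert read_tristate_flag("MY_FLAG") is False  # explicitly disabled
```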

View File

@@ -41,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError
 
     @abstractmethod
-    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              router_logits: torch.Tensor, top_k: int, renormalize: bool,
-              use_grouped_topk: bool) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         raise NotImplementedError
 
@@ -79,7 +90,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
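The two parameters added to the `apply()` contract exist for DeepSeek-V3-style routing: `scoring_func` selects the router activation (softmax vs. sigmoid) and `e_score_correction_bias` biases expert *selection* without changing the returned weights. A rough sketch of the routing idea only, not vLLM's actual `select_experts` implementation:

```python
from typing import Optional, Tuple
import torch

def route_tokens(
    router_logits: torch.Tensor,          # [num_tokens, num_experts]
    top_k: int,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if scoring_func == "softmax":
        scores = router_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        scores = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring_func: {scoring_func}")

    if e_score_correction_bias is not None:
        # The bias steers which experts win top-k, but the unbiased
        # scores are gathered back as the routing weights.
        topk_ids = (scores + e_score_correction_bias).topk(top_k, dim=-1).indices
        topk_weights = scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = scores.topk(top_k, dim=-1)
    return topk_weights, topk_ids
```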

View File

@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return torch.ops.vllm.fused_marlin_moe(
             x,

View File

@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return fused_experts(x,
                              layer.w13_weight,
@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return torch.ops.vllm.fused_marlin_moe(
             x,

View File

@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return fused_experts(x,
                              layer.w13_weight,

View File

@@ -601,14 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids = FusedMoE.select_experts(

View File

@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # The input must currently be float16
         orig_dtype = x.dtype
@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=None)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return torch.ops.vllm.fused_marlin_moe(
             x,

View File

@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.v1.engine.async_stream import AsyncStream
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
@@ -54,10 +53,8 @@ class AsyncLLM(EngineClient):
             lora_config=vllm_config.lora_config)
         self.tokenizer.ping()
 
-        # Request streams (map of request_id -> AsyncStream).
-        self.request_streams: Dict[str, AsyncStream] = {}
-        # List of cancelled request ids to be aborted.
-        self.client_aborted_requests: List[str] = []
+        # Request streams (map of request_id -> queue).
+        self.rid_to_queue: Dict[str, asyncio.Queue] = {}
 
         # Processor (converts Inputs --> EngineCoreRequests).
         self.processor = Processor(
@@ -153,14 +150,13 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
+    ) -> asyncio.Queue[RequestOutput]:
         """Add new request to the AsyncLLM."""
 
-        if self.detokenizer.is_request_active(request_id):
-            raise ValueError(f"Request {request_id} already exists.")
-
-        # 1) Create a new AsyncStream for the request.
-        stream = self._add_request_to_streams(request_id)
+        # 1) Create a new output queue for the request.
+        if request_id in self.rid_to_queue:
+            raise ValueError(f"Request id {request_id} already running.")
+        self.rid_to_queue[request_id] = asyncio.Queue()
 
         # 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
         detokenizer_req, engine_core_req = self.processor.process_inputs(
@@ -173,8 +169,10 @@ class AsyncLLM(EngineClient):
         # 4) Add the EngineCoreRequest to EngineCore (separate process).
         await self.engine_core.add_request_async(engine_core_req)
 
-        # 5) Return the generator.
-        return stream.generator()
+        if self.log_requests:
+            logger.info("Added request %s.", request_id)
+
+        return self.rid_to_queue[request_id]
 
     # TODO: we should support multiple prompts in one call, as you
     # can do with LLM.generate. So that for multi-prompt completion
@@ -194,7 +192,7 @@ class AsyncLLM(EngineClient):
         """
         Main function called by the API server to kick off a request
         * 1) Making an AsyncStream corresponding to the Request.
-        # 2) Processing the Input.
+        * 2) Processing the Input.
         * 3) Adding the Request to the Detokenizer.
         * 4) Adding the Request to the EngineCore (separate process).
@@ -206,14 +204,15 @@ class AsyncLLM(EngineClient):
         returning the RequestOutput back to the caller.
         """
 
-        # We start the output_handler on the first call to generate() so that
-        # we can call __init__ before the event loop starts, which enables us
-        # to handle startup failure gracefully in the OpenAI server.
-        if self.output_handler is None:
-            self.output_handler = asyncio.create_task(
-                self._run_output_handler())
+        try:
+            # We start the output_handler on the first call to generate() so
+            # we can call __init__ before the event loop, which enables us
+            # to handle startup failure gracefully in the OpenAI server.
+            if self.output_handler is None:
+                self.output_handler = asyncio.create_task(
+                    self._run_output_handler())
 
-        async for output in await self.add_request(
+            q = await self.add_request(
                 request_id,
                 prompt,
                 sampling_params,
@@ -221,79 +220,42 @@ class AsyncLLM(EngineClient):
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
                 priority=priority,
-        ):
-            yield output
+            )
 
-    def _finish_stream(self, request_id: str):
-        stream = self.request_streams.pop(request_id, None)
-        if stream is not None:
-            stream.finish()
+            # The output_handler task pushes items into the queue.
+            # This task pulls from the queue and yields to caller.
+            while True:
+                # Note: drain queue without await if possible (avoids
+                # task switching under load which helps performance).
+                out = q.get_nowait() if q.qsize() > 0 else await q.get()
 
-    def _add_request_to_streams(
-        self,
-        request_id: str,
-    ) -> AsyncStream:
+                # Note: both Detokenizer and EngineCore handle their
+                # own request cleanup based on finished.
+                if out.finished:
+                    del self.rid_to_queue[request_id]
+                    yield out
+                    break
 
-        if request_id in self.request_streams:
-            raise ValueError(f"Request id {request_id} already running.")
+                yield out
 
-        # Avoid streams having circular ref to parent AsyncLLM object.
-        aborted_reqs = self.client_aborted_requests
-        stream = AsyncStream(request_id, aborted_reqs.append)
-        self.request_streams[request_id] = stream
-
-        if self.log_requests:
-            logger.info("Added request %s.", request_id)
-
-        return stream
-
-    async def _process_cancellations(self) -> None:
-        """
-        Process requests cancelled from user disconnecting.
-
-        When a client disconnects, AsyncStream._cancel() is called.
-        We passed a callback to AsyncStream(), which appends to
-        self.client_aborted_requests.
-
-        As a result, if any requests are canceled from the user side
-        the request_id will show up in self.client_aborted_requests.
-        """
-
-        # Avoid streams having circular ref to parent AsyncLLM object.
-        if not self.client_aborted_requests:
-            return
-        reqs_to_abort = self.client_aborted_requests.copy()
-        self.client_aborted_requests.clear()
-
-        # Remove from Detokenizer.
-        self.detokenizer.abort_requests(reqs_to_abort)
-
-        # Remove from RequestStreams.
-        for request_id in reqs_to_abort:
-            if self.log_requests:
-                logger.info("User-cancelled request %s.", request_id)
-            self._finish_stream(request_id)
-
-        # Remove from EngineCore.
-        await self.engine_core.abort_requests_async(reqs_to_abort)
+        # If the request is disconnected by the client, the
+        # generate() task will be canceled. So, we abort the
+        # request if we end up here.
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise
 
     def _process_request_outputs(self, request_outputs: List[RequestOutput]):
-        """Process outputs by putting them into per-request AsyncStreams."""
+        """Process outputs by putting them into per-request queues."""
 
         for request_output in request_outputs:
             request_id = request_output.request_id
-            assert request_id in self.request_streams
 
-            # Each request in the API server pulls from the per-request stream.
-            stream = self.request_streams.get(request_id)
-            if stream is not None:
-                stream.put(request_output)
-
-                # If finished, remove from the tracker.
-                if request_output.finished:
-                    if self.log_requests:
-                        logger.info("Finished request %s.", request_id)
-                    self._finish_stream(request_id)
+            # Note: it is possible a request was aborted and removed from
+            # the state due to client cancellations, so if we encounter a
+            # request id not in the state, we skip.
+            if request_id in self.rid_to_queue:
+                self.rid_to_queue[request_id].put_nowait(request_output)
 
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
@@ -306,24 +268,27 @@ class AsyncLLM(EngineClient):
                 # 2) Detokenize based on the output.
                 request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
 
-                # 3) Put the RequestOutputs into the per-request AsyncStreams.
+                # 3) Put the RequestOutputs into the per-request queues.
                 self._process_request_outputs(request_outputs)
 
                 # 4) Abort any requests that finished due to stop strings.
                 await self.engine_core.abort_requests_async(reqs_to_abort)
 
-                # 5) Abort any requests due to client cancellations.
-                await self._process_cancellations()
-
         except BaseException as e:
             logger.error(e)
             raise e
 
-    # TODO: can we eliminate these?
-
     async def abort(self, request_id: str) -> None:
-        # Note: Who Calls this? I dont think this is actually used.
-        raise ValueError("Not Supported on V1 yet.")
+        """Abort RequestId in self, detokenizer, and engine core."""
+
+        request_ids = [request_id]
+        await self.engine_core.abort_requests_async(request_ids)
+        self.detokenizer.abort_requests(request_ids)
+
+        # If a request finishes while we await then the request_id
+        # will be removed from the tracked queues before we get here.
+        if request_id in self.rid_to_queue:
+            del self.rid_to_queue[request_id]
 
     def encode(
         self,
View File

@@ -1,55 +0,0 @@
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from vllm.outputs import PoolingRequestOutput, RequestOutput
class AsyncStream:
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
STOP_ITERATION = Exception() # Sentinel
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id
self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)
def finish(
self,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
) -> None:
if not self._finished:
self._finished = True
self._queue.put_nowait(exception if self._is_raisable(exception)
else AsyncStream.STOP_ITERATION)
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
finished = False
try:
while True:
result = await self._queue.get()
if self._is_raisable(result):
finished = True
if result == AsyncStream.STOP_ITERATION:
return
raise result
yield result
finally:
self._finished = True
if not finished:
self._cancel(self.request_id)
@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or \
(isinstance(value, type) and \
issubclass(value, BaseException))

View File

@@ -32,7 +32,7 @@ logger = init_logger(__name__)
 POLLING_TIMEOUT_MS = 5000
 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
-LOGGING_TIME_S = 5000
+LOGGING_TIME_S = POLLING_TIMEOUT_S
 
 class EngineCore:

View File

View File

@@ -0,0 +1,59 @@
+from typing import List, Set, Tuple
+
+import torch
+
+from vllm.model_executor.layers.utils import apply_penalties
+from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+
+
+def apply_min_token_penalties(logits: torch.Tensor,
+                              output_token_ids: List[List[int]],
+                              stop_token_ids: List[Set[int]],
+                              min_tokens: List[int]) -> None:
+    """
+    Applies minimum token penalty by setting the logits of the stop tokens
+    to -inf.
+    """
+    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
+    for index, min_token in enumerate(min_tokens):
+        if len(output_token_ids[index]) < min_token:
+            for stop_token_id in stop_token_ids[index]:
+                min_tokens_logits_to_penalize.append((index, stop_token_id))
+    if min_tokens_logits_to_penalize:
+        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
+
+
+def apply_all_penalties(
+    logits: torch.Tensor,
+    prompt_token_ids: torch.Tensor,
+    presence_penalties: torch.Tensor,
+    frequency_penalties: torch.Tensor,
+    repetition_penalties: torch.Tensor,
+    output_token_ids: List[List[int]],
+) -> torch.Tensor:
+    """
+    Applies presence, frequency and repetition penalties to the logits.
+    """
+    _, vocab_size = logits.shape
+    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
+                                          logits.device)
+    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
+                           presence_penalties, frequency_penalties,
+                           repetition_penalties)
+
+
+def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
+                        device: torch.device) -> torch.Tensor:
+    """
+    Convert the different list data structures to tensors.
+    """
+    output_tokens_tensor = make_tensor_with_pad(
+        output_token_ids,
+        # Use the value of vocab_size as a pad since we don't have a
+        # token_id of this value.
+        pad=vocab_size,
+        device="cpu",
+        dtype=torch.int64,
+        pin_memory=is_pin_memory_available(),
+    )
+    return output_tokens_tensor.to(device, non_blocking=True)
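A quick illustration of `apply_min_token_penalties` on a toy batch (tensors and token IDs invented for the example): requests that have generated fewer than `min_tokens` outputs get their stop tokens masked to `-inf` so they cannot be sampled.

```python
import torch

from vllm.v1.sample.ops.penalties import apply_min_token_penalties

logits = torch.zeros(2, 5)              # batch of 2, vocab of 5
output_token_ids = [[7], [8, 9, 10]]    # 1 and 3 tokens generated so far
stop_token_ids = [{4}, {4}]             # token 4 is the stop token
min_tokens = [2, 2]                     # both require at least 2 tokens

apply_min_token_penalties(logits, output_token_ids, stop_token_ids,
                          min_tokens)

assert logits[0, 4] == -float("inf")    # request 0: 1 < 2, stop masked
assert logits[1, 4] == 0                # request 1: 3 >= 2, untouched
```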

View File

@@ -0,0 +1,201 @@
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from vllm import envs
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+try:
+    import flashinfer.sampling
+    is_flashinfer_available = True
+except ImportError:
+    is_flashinfer_available = False
+
+
+class TopKTopPSampler(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for top-p & top-k sampling.")
+                    self.forward = self.forward_cuda
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "top-p & top-k sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of top-p & top-k sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward = self.forward_native
+        else:
+            self.forward = self.forward_native
+
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        no_top_k: bool,
+        k: torch.Tensor,
+        no_top_p: bool,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        """PyTorch-native implementation of top-k and top-p sampling."""
+        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
+
+    def forward_cuda(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        no_top_k: bool,
+        k: torch.Tensor,
+        no_top_p: bool,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        """More optimized implementation for top-k and top-p sampling."""
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        if no_top_k and no_top_p:
+            # We prefer `random_sample` over `flashinfer_sample` when sorting is
+            # not needed. This is because `random_sample` does not require
+            # CPU-GPU synchronization while `flashinfer_sample` does.
+            return random_sample(probs, generators)
+        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    This function sorts the logits tensor, which can be slow for large batches.
+    """
+    if no_top_k and no_top_p:
+        return logits
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if not no_top_k:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if not no_top_p:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def random_sample(
+    probs: torch.Tensor,
+    generators: Dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Randomly sample from the probabilities.
+
+    We use this function instead of torch.multinomial because torch.multinomial
+    causes CPU-GPU synchronization.
+    """
+    q = torch.empty_like(probs)
+    # NOTE(woosuk): To batch-process the requests without their own seeds,
+    # which is the common case, we first assume that every request does
+    # not have its own seed. Then, we overwrite the values for the requests
+    # that have their own seeds.
+    if len(generators) != probs.shape[0]:
+        q.exponential_()
+    if generators:
+        # TODO(woosuk): This can be slow because we handle each request
+        # one by one. Optimize this.
+        for i, generator in generators.items():
+            q[i].exponential_(generator=generator)
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def flashinfer_sample(
+    probs: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+    generators: Dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Sample from the probabilities using FlashInfer.
+
+    Statistically, this function is equivalent to the `random_sample` function.
+    However, this function is faster because it avoids sorting the logits tensor
+    via rejection sampling.
+
+    NOTE: The outputs of this function do not necessarily match the outputs of
+    the `random_sample` function. It only guarantees that the outputs are
+    statistically equivalent.
+
+    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
+    does not. Call this function at the end of the forward pass to minimize
+    the synchronization overhead.
+    """
+    assert not (no_top_k and no_top_p)
+    max_top_k_round = 32
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if len(generators) != batch_size:
+        uniform_samples.uniform_()
+    if generators:
+        for i, generator in generators.items():
+            uniform_samples[:, i].uniform_(generator=generator)
+
+    if no_top_k:
+        # Top-p only.
+        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
+            probs, uniform_samples, p, deterministic=True)
+    elif no_top_p:
+        # Top-k only.
+        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
+            probs, uniform_samples, k, deterministic=True)
+    else:
+        # Both top-k and top-p.
+        next_token_ids, success = (
+            flashinfer.sampling.top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, k, p, deterministic=True))
+
+    # NOTE: CPU-GPU synchronization happens here.
+    if not success.all():
+        if not no_top_k:
+            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
+        if not no_top_p:
+            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
+        next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0], deterministic=True)
+    return next_token_ids.view(-1)
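`random_sample` above avoids `torch.multinomial` via the exponential-race form of the Gumbel-max trick: if the `E_i` are i.i.d. Exponential(1), then `argmax_i p_i / E_i` is distributed exactly according to `p`, with no CPU-GPU sync required. A small sanity check of that equivalence on toy sizes:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)

for _ in range(20_000):
    q = torch.empty_like(probs).exponential_()
    # Same computation as random_sample(): divide by the noise, argmax.
    counts[(probs / q).argmax()] += 1

print(counts / counts.sum())  # empirically close to [0.1, 0.2, 0.7]
```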

View File

@@ -1,53 +1,55 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Set, Tuple
+from typing import Tuple
 
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.layers.utils import apply_penalties
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.penalties import (apply_all_penalties,
+                                          apply_min_token_penalties)
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 
 _SAMPLING_EPS = 1e-5
 
 
 class Sampler(nn.Module):
 
+    def __init__(self):
+        super().__init__()
+        self.topk_topp_sampler = TopKTopPSampler()
+
     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        _apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
-                                   sampling_metadata.stop_token_ids,
-                                   sampling_metadata.min_tokens)
-        if not sampling_metadata.no_penalties:
-            assert sampling_metadata.prompt_token_ids is not None
-            _apply_penalties(logits, sampling_metadata.prompt_token_ids,
-                             sampling_metadata.presence_penalties,
-                             sampling_metadata.frequency_penalties,
-                             sampling_metadata.repetition_penalties,
-                             sampling_metadata.output_token_ids)
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
-        logits = self.apply_top_k_top_p(logits, sampling_metadata)
-        probs = self.get_probs(logits)
-        sampled = self.sample(probs, sampling_metadata)
-        # Use int32 to reduce the tensor size.
-        sampled = sampled.to(torch.int32)
-        if sampling_metadata.max_num_logprobs > 0:
-            logprobs = self.get_logprobs(logits)
-            # FIXME: Mask the sampled token_id, get topk logprobs,
-            # and concatenate the topk with the sampled token_id.
-            topk_logprobs, topk_indices = torch.topk(
-                logprobs, sampling_metadata.max_num_logprobs, dim=-1)
-            # Use int32 to reduce the tensor size.
-            topk_indices = topk_indices.to(torch.int32)
+        needs_logprobs = sampling_metadata.max_num_logprobs > 0
+        if needs_logprobs:
+            # NOTE(woosuk): Use the original logits (before any penalties or
+            # temperature scaling) for the top-k logprobs.
+            # This is different from the V0 sampler, which uses the logits that
+            # is used for sampling (after penalties and temperature scaling).
+            # NOTE: We compute logprobs first because the below ops may
+            # modify the logits tensor in-place (and we don't want to clone
+            # the logits tensor for memory efficiency).
+            topk_logprobs, topk_indices = self.get_topk_logprobs(
+                logits, sampling_metadata)
         else:
             topk_logprobs = None
             topk_indices = None
+
+        # Use float32 for the logits.
+        logits = logits.to(torch.float32)
+        # Apply penalties (e.g., min_tokens, freq_penalties).
+        logits = self.apply_penalties(logits, sampling_metadata)
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        # Sample the next token.
+        sampled = self.sample(logits, sampling_metadata)
+        # Use int32 to reduce the tensor size.
+        sampled = sampled.to(torch.int32)
 
         # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
             sampled_token_ids=sampled.tolist(),
@@ -63,71 +65,37 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Use float32 to apply temperature scaling.
-        logits = logits.to(torch.float32)
         # Avoid division by zero.
         temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         # Use in-place division to avoid creating a new tensor.
         logits.div_(temp.unsqueeze(dim=1))
         return logits
 
-    def apply_top_k_top_p(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
-        return _apply_top_k_top_p(
-            logits,
-            sampling_metadata.no_top_k,
-            sampling_metadata.top_k,
-            sampling_metadata.no_top_p,
-            sampling_metadata.top_p,
-        )
-
-    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.softmax(logits, dim=-1, dtype=torch.float32)
-
-    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
-
-    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
-        return probs.argmax(dim=-1).view(-1)
-
-    def random_sample(
-        self,
-        probs: torch.Tensor,
-        generators: Dict[int, torch.Generator],
-    ) -> torch.Tensor:
-        q = torch.empty_like(probs)
-        # NOTE(woosuk): To batch-process the requests without their own seeds,
-        # which is the common case, we first assume that every request does
-        # not have its own seed. Then, we overwrite the values for the requests
-        # that have their own seeds.
-        if len(generators) != probs.shape[0]:
-            # This might still be done here unnecessarily if there are greedies
-            q.exponential_()
-        if generators:
-            # TODO(woosuk): This can be slow because we handle each request
-            # one by one. Optimize this.
-            for i, generator in generators.items():
-                q[i].exponential_(generator=generator)
-        return probs.div_(q).argmax(dim=-1).view(-1)
+    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+        return logits.argmax(dim=-1).view(-1)
 
     def sample(
         self,
-        probs: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
         if sampling_metadata.all_greedy:
-            return self.greedy_sample(probs)
-        if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators)
+            return self.greedy_sample(logits)
 
-        greedy_sampled = self.greedy_sample(probs)
-        random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators)
+        random_sampled = self.topk_topp_sampler(
+            logits,
+            sampling_metadata.generators,
+            sampling_metadata.no_top_k,
+            sampling_metadata.top_k,
+            sampling_metadata.no_top_p,
+            sampling_metadata.top_p,
+        )
+        if sampling_metadata.all_random:
+            return random_sampled
+
+        greedy_sampled = self.greedy_sample(logits)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
@@ -135,86 +103,34 @@ class Sampler(nn.Module):
         )
         return sampled
 
+    def get_topk_logprobs(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
+        # FIXME: Mask the sampled token_id, get topk logprobs,
+        # and concatenate the topk with the sampled token_id.
+        topk_logprobs, topk_indices = torch.topk(
+            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
+        # Use int32 to reduce the tensor size.
+        topk_indices = topk_indices.to(torch.int32)
+        return topk_logprobs, topk_indices
 
-# TODO(woosuk): Optimize this with a custom kernel.
-def _apply_top_k_top_p(
-    logits: torch.Tensor,
-    no_top_k: bool,
-    k: torch.Tensor,
-    no_top_p: bool,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    if no_top_k and no_top_p:
+    def apply_penalties(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
+                                  sampling_metadata.stop_token_ids,
+                                  sampling_metadata.min_tokens)
+        if not sampling_metadata.no_penalties:
+            assert sampling_metadata.prompt_token_ids is not None
+            logits = apply_all_penalties(
+                logits, sampling_metadata.prompt_token_ids,
+                sampling_metadata.presence_penalties,
+                sampling_metadata.frequency_penalties,
+                sampling_metadata.repetition_penalties,
+                sampling_metadata.output_token_ids)
         return logits
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-
-    if not no_top_k:
-        # Apply top-k.
-        top_k_mask = logits_sort.size(1) - k.to(torch.long)
-        # Get all the top_k values.
-        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
-        top_k_mask = logits_sort < top_k_mask
-        logits_sort.masked_fill_(top_k_mask, -float("inf"))
-
-    if not no_top_p:
-        # Apply top-p.
-        probs_sort = logits_sort.softmax(dim=-1)
-        probs_sum = probs_sort.cumsum(dim=-1)
-        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
-        # at least one
-        top_p_mask[:, -1] = False
-        logits_sort.masked_fill_(top_p_mask, -float("inf"))
-
-    # Re-sort the probabilities.
-    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
-    return logits
-
-
-def _apply_min_token_penalties(logits: torch.Tensor,
-                               output_token_ids: List[List[int]],
-                               stop_token_ids: List[Set[int]],
-                               min_tokens: List[int]):
-    """
-    Applies minimum token penalty by setting the logits of the stop tokens
-    to -inf.
-    """
-    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
-    for index, min_token in enumerate(min_tokens):
-        if (len(output_token_ids[index]) < min_token):
-            for stop_token_id in stop_token_ids[index]:
-                min_tokens_logits_to_penalize.append((index, stop_token_id))
-    if min_tokens_logits_to_penalize:
-        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
-
-
-def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
-                     presence_penalties: torch.Tensor,
-                     frequency_penalties: torch.Tensor,
-                     repetition_penalties: torch.Tensor,
-                     output_token_ids: List[List[int]]):
-    """
-    Applies presence, frequency and repetition penalties to the logits.
-    """
-    _, vocab_size = logits.shape
-    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
-                                          logits.device)
-    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
-                           presence_penalties, frequency_penalties,
-                           repetition_penalties)
-
-
-def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
-                        device: torch.device) -> torch.Tensor:
-    """
-    Convert the different list data structures to tensors.
-    """
-    output_tokens_tensor = make_tensor_with_pad(
-        output_token_ids,
-        # Use the value of vocab_size as a pad since we don't have a
-        # token_id of this value.
-        pad=vocab_size,
-        device="cpu",
-        dtype=torch.int64,
-        pin_memory=is_pin_memory_available(),
-    )
-    return output_tokens_tensor.to(device, non_blocking=True)
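One detail worth calling out in the new `forward()`: `apply_penalties` and `apply_temperature` mutate the logits in place (masked fills, `div_`), so the top-k logprobs are computed first to keep them based on the raw model outputs without cloning the tensor. A toy illustration of the hazard the ordering avoids:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.0]])

# What the new forward() does: take logprobs BEFORE any in-place edits.
logprobs_before = logits.log_softmax(dim=-1)

logits.div_(2.0)  # in-place temperature scaling mutates `logits`...
logprobs_after = logits.log_softmax(dim=-1)

# ...so logprobs taken afterwards would reflect the scaled logits instead.
assert not torch.allclose(logprobs_before, logprobs_after)
```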