Compare commits

v0.6.6 ... v0.6.6.pos

8 commits:

- 2339d59f92
- 1b875a0ef3
- eb881ed006
- 46d4359450
- 81b979f2a8
- 371d04d39b
- 0c0c2015c5
- 82d24f7aac
@@ -60,7 +60,7 @@ vLLM is flexible and easy to use with:

 vLLM seamlessly supports most popular open-source models on HuggingFace, including:
 - Transformer-like LLMs (e.g., Llama)
-- Mixture-of-Expert LLMs (e.g., Mixtral)
+- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
 - Embedding Models (e.g. E5-Mistral)
 - Multi-modal LLMs (e.g., LLaVA)
@@ -137,6 +137,11 @@ See [this page](#generative-models) for more information on how to use generative models.
     - :code:`deepseek-ai/DeepSeek-V2`, :code:`deepseek-ai/DeepSeek-V2-Chat` etc.
     -
     - ✅︎
+  * - :code:`DeepseekV3ForCausalLM`
+    - DeepSeek-V3
+    - :code:`deepseek-ai/DeepSeek-V3-Base`, :code:`deepseek-ai/DeepSeek-V3` etc.
+    -
+    - ✅︎
   * - :code:`ExaoneForCausalLM`
     - EXAONE-3
     - :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
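With the table rows above, DeepSeek-V3 checkpoints become loadable by architecture name. As a hedged illustration (not part of this diff), offline use through the Python API might look like the sketch below; the `tensor_parallel_size` value is a placeholder, since DeepSeek-V3 is far too large for a single GPU:

```python
# Illustrative sketch only: the flags mirror the registry entry added
# in this release (trust_remote_code is required for DeepSeek-V3).
from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V3",
          trust_remote_code=True,
          tensor_parallel_size=8)  # placeholder; size to your GPU count
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```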
@@ -676,7 +681,7 @@ See [this page](#generative-models) for more information on how to use generative models.
     - PaliGemma, PaliGemma 2
     - T + I\ :sup:`E`
     - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, :code:`google/paligemma2-3b-ft-docci-448`, etc.
     -
     - ✅︎
     -
   * - :code:`Phi3VForCausalLM`
@@ -112,7 +112,13 @@ completion = client.chat.completions.create(

 ## Extra HTTP Headers

-Only `X-Request-Id` HTTP request header is supported for now.
+Only `X-Request-Id` HTTP request header is supported for now. It can be enabled
+with `--enable-request-id-headers`.
+
+> Note that enablement of the headers can impact performance significantly at high QPS
+> rates. We recommend implementing HTTP headers at the router level (e.g. via Istio),
+> rather than within the vLLM layer for this reason.
+> See https://github.com/vllm-project/vllm/pull/11529 for more details.

 ```python
 completion = client.chat.completions.create(
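For reference, a client-side request that attaches the header could look like the following sketch; it assumes a server started with `--enable-request-id-headers` and relies on the standard `openai` client's `extra_headers` option (nothing vLLM-specific). The base URL, model name, and id value are placeholders:

```python
# Sketch: attach an X-Request-Id header via the OpenAI-compatible API.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_headers={"X-Request-Id": "my-trace-id-123"},
)
```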
@@ -61,6 +61,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
     "DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat",  # noqa: E501
                                              trust_remote_code=True),
+    "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3",  # noqa: E501
+                                             trust_remote_code=True),
     "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),  # noqa: E501
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
@@ -68,7 +68,7 @@ def _create_default_sampling_metadata(
         no_top_p=True,
         no_top_k=True,
         generators={},
-        max_num_logprobs=VOCAB_SIZE,
+        max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
@@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     sampling_metadata.min_tokens = min_tokens
     sampling_metadata.stop_token_ids = stop_token_ids
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        for vocab in range(VOCAB_SIZE):
-            # Verify that the logprobs for stop token ids is set
-            # to -inf.
-            logprob_index = torch.where(
-                sampler_output.logprob_token_ids[batch_idx] ==
-                vocab)[0].item()
-            if vocab in stop_token_ids[batch_idx]:
-                assert sampler_output.logprobs[batch_idx][
-                    logprob_index] == -float("inf")
+        for token_id in range(VOCAB_SIZE):
+            if token_id in stop_token_ids[batch_idx]:
+                assert logits[batch_idx][token_id] == -float("inf")
             else:
-                assert sampler_output.logprobs[batch_idx][
-                    logprob_index] != -float("inf")
+                assert logits[batch_idx][token_id] != -float("inf")


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
         batch_size, presence_penalty, torch.device(device))
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        # The logprobs in the SamplerOutput are arranged in descending order.
-        # Since all tokens initially have the same logprobs, the non-penalized
-        # tokens will appear at the beginning, while the penalized tokens
-        # will appear at the end of the list.
-        penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
-            VOCAB_SIZE - 1]
-        penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
-        non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
-        non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
-        assert non_penalized_log_prod > penalized_log_prod
+        # Since all tokens initially have the same logits, the non-penalized
+        # token ID will be the one with the highest logit value, while the
+        # penalized token ID will be the one with the lowest logit value.
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         if presence_penalty > 0:
             # If `presence_penalty` is set to a value greater than 0, it
             # indicates a preference for new tokens over those already
@@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
-        non_penalized_token_id = logprobs_token_ids[0]
-        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         distinct_sorted_token_ids_in_output = \
             sorted_token_ids_in_output[batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
@@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
         batch_size, repetition_penalty, torch.device(device))
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
-        non_penalized_token_id = logprobs_token_ids[0]
-        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         prompt_tokens = sampling_metadata.prompt_token_ids[
             batch_idx][:].tolist()
         output_tokens = sampling_metadata.output_token_ids[batch_idx]
@@ -208,8 +208,8 @@ def wrap_inductor(graph: fx.GraphModule,
     from torch._inductor.compile_fx import graph_returns_tuple
     returns_tuple = graph_returns_tuple(graph)

-    # this is the graph we return to Dynamo to run
-    def compiled_graph(*args) -> Optional[fx.CompiledFxGraph]:
+    # this is the callable we return to Dynamo to run
+    def compiled_graph(*args):
         # convert args to list
         list_args = list(args)
         graph_output = inductor_compiled_graph(list_args)
@@ -537,7 +537,8 @@ class VllmBackend:
             example_inputs[x].clone() for x in self.sym_tensor_indices
         ]

-        def copy_and_call(*args) -> fx.GraphModule:
+        # this is the callable we return to Dynamo to run
+        def copy_and_call(*args):
             list_args = list(args)
             for i, index in enumerate(self.sym_tensor_indices):
                 runtime_tensor = list_args[index]
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
-    VLLM_USE_FLASHINFER_SAMPLER: bool = False
+    VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
    VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
    VLLM_PP_LAYER_PARTITION: Optional[str] = None
@@ -277,7 +277,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {

     # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER":
-    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
+    lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
+    if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,

     # If set, vllm will force flashinfer to use tensor cores;
     # otherwise will use heuristic based on model architecture.
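The new lambda distinguishes three states instead of two. A standalone sketch of the same parsing behavior (not vLLM code):

```python
import os
from typing import Optional

def parse_flashinfer_sampler() -> Optional[bool]:
    # Unset -> None (each sampler version then picks its own default);
    # "0" -> False; any other integer string -> True.
    if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ:
        return bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
    return None
```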
@@ -41,9 +41,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError

     @abstractmethod
-    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              router_logits: torch.Tensor, top_k: int, renormalize: bool,
-              use_grouped_topk: bool) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         raise NotImplementedError

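The two added parameters exist to support DeepSeek-V3-style routing: `scoring_func="sigmoid"` scores experts with a sigmoid instead of a softmax, and `e_score_correction_bias` is a learned bias that influences which experts are *selected* while the combine weights still come from the unbiased scores. A minimal sketch of that idea, assuming plain top-k routing rather than vLLM's fused grouped implementation:

```python
import torch

def select_experts_sketch(router_logits: torch.Tensor, top_k: int,
                          e_score_correction_bias: torch.Tensor):
    # Sigmoid scoring, as used by DeepSeek-V3 (scoring_func="sigmoid").
    scores = router_logits.sigmoid()
    # The bias only affects which experts are picked...
    topk_ids = (scores + e_score_correction_bias).topk(top_k, dim=-1).indices
    # ...while the routing weights use the original, unbiased scores,
    # renormalized over the selected experts (renormalize=True).
    topk_weights = scores.gather(-1, topk_ids)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```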
@@ -79,7 +90,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
@@ -440,11 +440,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -454,7 +456,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,
@@ -203,13 +203,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

         from vllm.model_executor.layers.fused_moe import fused_experts

         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -220,7 +221,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return fused_experts(x,
                              layer.w13_weight,
@@ -476,12 +479,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -490,7 +496,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,
@@ -99,11 +99,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
@@ -115,7 +117,9 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return fused_experts(x,
                              layer.w13_weight,
@@ -601,14 +601,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

         from vllm.model_executor.layers.fused_moe import fused_experts

         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -532,11 +532,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # The input must currently be float16
         orig_dtype = x.dtype
@@ -550,7 +552,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=None)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,
@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.v1.engine.async_stream import AsyncStream
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
@@ -54,10 +53,8 @@ class AsyncLLM(EngineClient):
             lora_config=vllm_config.lora_config)
         self.tokenizer.ping()

-        # Request streams (map of request_id -> AsyncStream).
-        self.request_streams: Dict[str, AsyncStream] = {}
-        # List of cancelled request ids to be aborted.
-        self.client_aborted_requests: List[str] = []
+        # Request streams (map of request_id -> queue).
+        self.rid_to_queue: Dict[str, asyncio.Queue] = {}

         # Processor (converts Inputs --> EngineCoreRequests).
         self.processor = Processor(
@@ -153,14 +150,13 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
+    ) -> asyncio.Queue[RequestOutput]:
         """Add new request to the AsyncLLM."""

-        if self.detokenizer.is_request_active(request_id):
-            raise ValueError(f"Request {request_id} already exists.")
-
-        # 1) Create a new AsyncStream for the request.
-        stream = self._add_request_to_streams(request_id)
+        # 1) Create a new output queue for the request.
+        if request_id in self.rid_to_queue:
+            raise ValueError(f"Request id {request_id} already running.")
+        self.rid_to_queue[request_id] = asyncio.Queue()

         # 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
         detokenizer_req, engine_core_req = self.processor.process_inputs(
@@ -173,8 +169,10 @@ class AsyncLLM(EngineClient):
         # 4) Add the EngineCoreRequest to EngineCore (separate process).
         await self.engine_core.add_request_async(engine_core_req)

-        # 5) Return the generator.
-        return stream.generator()
+        if self.log_requests:
+            logger.info("Added request %s.", request_id)
+
+        return self.rid_to_queue[request_id]

     # TODO: we should support multiple prompts in one call, as you
     # can do with LLM.generate. So that for multi-prompt completion
@@ -194,7 +192,7 @@ class AsyncLLM(EngineClient):
         """
         Main function called by the API server to kick off a request
         * 1) Making an AsyncStream corresponding to the Request.
-        # 2) Processing the Input.
+        * 2) Processing the Input.
         * 3) Adding the Request to the Detokenizer.
         * 4) Adding the Request to the EngineCore (separate process).

@@ -206,14 +204,15 @@ class AsyncLLM(EngineClient):
         returning the RequestOutput back to the caller.
         """

-        # We start the output_handler on the first call to generate() so that
-        # we can call __init__ before the event loop starts, which enables us
-        # to handle startup failure gracefully in the OpenAI server.
-        if self.output_handler is None:
-            self.output_handler = asyncio.create_task(
-                self._run_output_handler())
+        try:
+            # We start the output_handler on the first call to generate() so
+            # we can call __init__ before the event loop, which enables us
+            # to handle startup failure gracefully in the OpenAI server.
+            if self.output_handler is None:
+                self.output_handler = asyncio.create_task(
+                    self._run_output_handler())

-        async for output in await self.add_request(
-            request_id,
-            prompt,
-            sampling_params,
+            q = await self.add_request(
+                request_id,
+                prompt,
+                sampling_params,
@@ -221,79 +220,42 @@ class AsyncLLM(EngineClient):
-            trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request,
-            priority=priority,
-        ):
-            yield output
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+            )

-    def _finish_stream(self, request_id: str):
-        stream = self.request_streams.pop(request_id, None)
-        if stream is not None:
-            stream.finish()
+            # The output_handler task pushes items into the queue.
+            # This task pulls from the queue and yields to caller.
+            while True:
+                # Note: drain queue without await if possible (avoids
+                # task switching under load which helps performance).
+                out = q.get_nowait() if q.qsize() > 0 else await q.get()

-    def _add_request_to_streams(
-        self,
-        request_id: str,
-    ) -> AsyncStream:
+                # Note: both Detokenizer and EngineCore handle their
+                # own request cleanup based on finished.
+                if out.finished:
+                    del self.rid_to_queue[request_id]
+                    yield out
+                    break

-        if request_id in self.request_streams:
-            raise ValueError(f"Request id {request_id} already running.")
+                yield out

-        # Avoid streams having circular ref to parent AsyncLLM object.
-        aborted_reqs = self.client_aborted_requests
-        stream = AsyncStream(request_id, aborted_reqs.append)
-        self.request_streams[request_id] = stream
+        # If the request is disconnected by the client, the
+        # generate() task will be canceled. So, we abort the
+        # request if we end up here.
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise

-        if self.log_requests:
-            logger.info("Added request %s.", request_id)
-
-        return stream
-
-    async def _process_cancellations(self) -> None:
-        """
-        Process requests cancelled from user disconnecting.
-
-        When a client disconnects, AsyncStream._cancel() is called.
-        We passed a callback to AsyncStream(), which appends to
-        self.client_aborted_requests.
-
-        As a result, if any requests are canceled from the user side
-        the request_id will show up in self.client_aborted_requests.
-        """
-
-        # Avoid streams having circular ref to parent AsyncLLM object.
-        if not self.client_aborted_requests:
-            return
-        reqs_to_abort = self.client_aborted_requests.copy()
-        self.client_aborted_requests.clear()
-
-        # Remove from Detokenizer.
-        self.detokenizer.abort_requests(reqs_to_abort)
-
-        # Remove from RequestStreams.
-        for request_id in reqs_to_abort:
-            if self.log_requests:
-                logger.info("User-cancelled request %s.", request_id)
-            self._finish_stream(request_id)
-
-        # Remove from EngineCore.
-        await self.engine_core.abort_requests_async(reqs_to_abort)
-
     def _process_request_outputs(self, request_outputs: List[RequestOutput]):
-        """Process outputs by putting them into per-request AsyncStreams."""
+        """Process outputs by putting them into per-request queues."""

         for request_output in request_outputs:
             request_id = request_output.request_id
-            assert request_id in self.request_streams

-            # Each request in the API server pulls from the per-request stream.
-            stream = self.request_streams.get(request_id)
-            if stream is not None:
-                stream.put(request_output)
-
-                # If finished, remove from the tracker.
-                if request_output.finished:
-                    if self.log_requests:
-                        logger.info("Finished request %s.", request_id)
-                    self._finish_stream(request_id)
+            # Note: it is possible a request was aborted and removed from
+            # the state due to client cancellations, so if we encounter a
+            # request id not in the state, we skip.
+            if request_id in self.rid_to_queue:
+                self.rid_to_queue[request_id].put_nowait(request_output)

     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
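The `get_nowait()`-before-`await` pattern in the new `generate()` loop above is a deliberate micro-optimization; isolated as a sketch (the `finished` attribute stands in for the `RequestOutput` flag used above):

```python
import asyncio

async def drain(q: asyncio.Queue):
    # Pulling synchronously while items are already queued avoids one
    # await (and the task switch it implies) per item under load; the
    # coroutine only suspends when the queue is actually empty.
    while True:
        item = q.get_nowait() if q.qsize() > 0 else await q.get()
        yield item
        if getattr(item, "finished", False):
            break
```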
@@ -306,24 +268,27 @@ class AsyncLLM(EngineClient):
                 # 2) Detokenize based on the output.
                 request_outputs, reqs_to_abort = self.detokenizer.step(outputs)

-                # 3) Put the RequestOutputs into the per-request AsyncStreams.
+                # 3) Put the RequestOutputs into the per-request queues.
                 self._process_request_outputs(request_outputs)

                 # 4) Abort any requests that finished due to stop strings.
                 await self.engine_core.abort_requests_async(reqs_to_abort)

-                # 5) Abort any requests due to client cancellations.
-                await self._process_cancellations()
-
         except BaseException as e:
             logger.error(e)
             raise e

-    # TODO: can we eliminate these?
-
     async def abort(self, request_id: str) -> None:
-        # Note: Who Calls this? I dont think this is actually used.
-        raise ValueError("Not Supported on V1 yet.")
+        """Abort RequestId in self, detokenizer, and engine core."""
+
+        request_ids = [request_id]
+        await self.engine_core.abort_requests_async(request_ids)
+        self.detokenizer.abort_requests(request_ids)
+
+        # If a request finishes while we await then the request_id
+        # will be removed from the tracked queues before we get here.
+        if request_id in self.rid_to_queue:
+            del self.rid_to_queue[request_id]

     def encode(
         self,
vllm/v1/engine/async_stream.py (deleted file)

@@ -1,55 +0,0 @@
-import asyncio
-from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
-
-from vllm.outputs import PoolingRequestOutput, RequestOutput
-
-
-class AsyncStream:
-    """A stream of RequestOutputs or PoolingRequestOutputs for a request
-    that can be iterated over asynchronously via an async generator."""
-
-    STOP_ITERATION = Exception()  # Sentinel
-
-    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
-        self.request_id = request_id
-        self._cancel = cancel
-        self._queue: asyncio.Queue = asyncio.Queue()
-        self._finished = False
-
-    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
-                              Exception]) -> None:
-        if not self._finished:
-            self._queue.put_nowait(item)
-
-    def finish(
-        self,
-        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
-    ) -> None:
-        if not self._finished:
-            self._finished = True
-            self._queue.put_nowait(exception if self._is_raisable(exception)
-                                   else AsyncStream.STOP_ITERATION)
-
-    async def generator(
-        self
-    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
-        finished = False
-        try:
-            while True:
-                result = await self._queue.get()
-                if self._is_raisable(result):
-                    finished = True
-                    if result == AsyncStream.STOP_ITERATION:
-                        return
-                    raise result
-                yield result
-        finally:
-            self._finished = True
-            if not finished:
-                self._cancel(self.request_id)
-
-    @staticmethod
-    def _is_raisable(value: Any):
-        return isinstance(value, BaseException) or \
-            (isinstance(value, type) and \
-                issubclass(value, BaseException))
@@ -32,7 +32,7 @@ logger = init_logger(__name__)

 POLLING_TIMEOUT_MS = 5000
 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
-LOGGING_TIME_S = 5000
+LOGGING_TIME_S = POLLING_TIMEOUT_S


 class EngineCore:
vllm/v1/sample/ops/__init__.py (new file, 0 lines)

vllm/v1/sample/ops/penalties.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+from typing import List, Set, Tuple
+
+import torch
+
+from vllm.model_executor.layers.utils import apply_penalties
+from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+
+
+def apply_min_token_penalties(logits: torch.Tensor,
+                              output_token_ids: List[List[int]],
+                              stop_token_ids: List[Set[int]],
+                              min_tokens: List[int]) -> None:
+    """
+    Applies minimum token penalty by setting the logits of the stop tokens
+    to -inf.
+    """
+    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
+    for index, min_token in enumerate(min_tokens):
+        if len(output_token_ids[index]) < min_token:
+            for stop_token_id in stop_token_ids[index]:
+                min_tokens_logits_to_penalize.append((index, stop_token_id))
+    if min_tokens_logits_to_penalize:
+        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
+
+
+def apply_all_penalties(
+    logits: torch.Tensor,
+    prompt_token_ids: torch.Tensor,
+    presence_penalties: torch.Tensor,
+    frequency_penalties: torch.Tensor,
+    repetition_penalties: torch.Tensor,
+    output_token_ids: List[List[int]],
+) -> torch.Tensor:
+    """
+    Applies presence, frequency and repetition penalties to the logits.
+    """
+    _, vocab_size = logits.shape
+    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
+                                          logits.device)
+    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
+                           presence_penalties, frequency_penalties,
+                           repetition_penalties)
+
+
+def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
+                        device: torch.device) -> torch.Tensor:
+    """
+    Convert the different list data structures to tensors.
+    """
+    output_tokens_tensor = make_tensor_with_pad(
+        output_token_ids,
+        # Use the value of vocab_size as a pad since we don't have a
+        # token_id of this value.
+        pad=vocab_size,
+        device="cpu",
+        dtype=torch.int64,
+        pin_memory=is_pin_memory_available(),
+    )
+    return output_tokens_tensor.to(device, non_blocking=True)
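A small usage sketch for the helper above (the values are hypothetical):

```python
import torch
# from vllm.v1.sample.ops.penalties import apply_min_token_penalties

logits = torch.zeros(1, 8)     # batch of 1, vocab of 8
output_token_ids = [[5, 2]]    # 2 tokens generated so far
stop_token_ids = [{7}]         # EOS-like stop token id
min_tokens = [4]               # at least 4 tokens must be generated
apply_min_token_penalties(logits, output_token_ids, stop_token_ids,
                          min_tokens)
assert logits[0, 7] == -float("inf")  # stop token masked until min_tokens
```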
vllm/v1/sample/ops/topk_topp_sampler.py (new file, 201 lines)

@@ -0,0 +1,201 @@
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from vllm import envs
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+try:
+    import flashinfer.sampling
+    is_flashinfer_available = True
+except ImportError:
+    is_flashinfer_available = False
+
+
+class TopKTopPSampler(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for top-p & top-k sampling.")
+                    self.forward = self.forward_cuda
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "top-p & top-k sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of top-p & top-k sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward = self.forward_native
+        else:
+            self.forward = self.forward_native
+
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        no_top_k: bool,
+        k: torch.Tensor,
+        no_top_p: bool,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        """PyTorch-native implementation of top-k and top-p sampling."""
+        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
+
+    def forward_cuda(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        no_top_k: bool,
+        k: torch.Tensor,
+        no_top_p: bool,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        """More optimized implementation for top-k and top-p sampling."""
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        if no_top_k and no_top_p:
+            # We prefer `random_sample` over `flashinfer_sample` when sorting
+            # is not needed. This is because `random_sample` does not require
+            # CPU-GPU synchronization while `flashinfer_sample` does.
+            return random_sample(probs, generators)
+        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    This function sorts the logits tensor, which can be slow for large batches.
+    """
+    if no_top_k and no_top_p:
+        return logits
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if not no_top_k:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if not no_top_p:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def random_sample(
+    probs: torch.Tensor,
+    generators: Dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Randomly sample from the probabilities.
+
+    We use this function instead of torch.multinomial because torch.multinomial
+    causes CPU-GPU synchronization.
+    """
+    q = torch.empty_like(probs)
+    # NOTE(woosuk): To batch-process the requests without their own seeds,
+    # which is the common case, we first assume that every request does
+    # not have its own seed. Then, we overwrite the values for the requests
+    # that have their own seeds.
+    if len(generators) != probs.shape[0]:
+        q.exponential_()
+    if generators:
+        # TODO(woosuk): This can be slow because we handle each request
+        # one by one. Optimize this.
+        for i, generator in generators.items():
+            q[i].exponential_(generator=generator)
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def flashinfer_sample(
+    probs: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+    generators: Dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Sample from the probabilities using FlashInfer.
+
+    Statistically, this function is equivalent to the `random_sample` function.
+    However, this function is faster because it avoids sorting the logits tensor
+    via rejection sampling.
+
+    NOTE: The outputs of this function do not necessarily match the outputs of
+    the `random_sample` function. It only guarantees that the outputs are
+    statistically equivalent.
+
+    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
+    does not. Call this function at the end of the forward pass to minimize
+    the synchronization overhead.
+    """
+    assert not (no_top_k and no_top_p)
+    max_top_k_round = 32
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if len(generators) != batch_size:
+        uniform_samples.uniform_()
+    if generators:
+        for i, generator in generators.items():
+            uniform_samples[:, i].uniform_(generator=generator)
+
+    if no_top_k:
+        # Top-p only.
+        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
+            probs, uniform_samples, p, deterministic=True)
+    elif no_top_p:
+        # Top-k only.
+        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
+            probs, uniform_samples, k, deterministic=True)
+    else:
+        # Both top-k and top-p.
+        next_token_ids, success = (
+            flashinfer.sampling.top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, k, p, deterministic=True))
+
+    # NOTE: CPU-GPU synchronization happens here.
+    if not success.all():
+        if not no_top_k:
+            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
+        if not no_top_p:
+            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
+        next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0], deterministic=True)
+    return next_token_ids.view(-1)
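The `probs.div_(q).argmax()` line in `random_sample` above is an "exponential race": with i.i.d. `q_i ~ Exp(1)`, `argmax_i(p_i / q_i)` picks index `i` with probability proportional to `p_i`, so it is an exact categorical sample computed without sorting or CPU-GPU synchronization. A quick empirical sanity check:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(20_000):
    q = torch.empty_like(probs).exponential_()
    counts[(probs / q).argmax()] += 1
print(counts / counts.sum())  # approaches [0.1, 0.2, 0.7]
```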
@@ -1,53 +1,55 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Set, Tuple
+from typing import Tuple

 import torch
 import torch.nn as nn

-from vllm.model_executor.layers.utils import apply_penalties
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.penalties import (apply_all_penalties,
+                                          apply_min_token_penalties)
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

 _SAMPLING_EPS = 1e-5


 class Sampler(nn.Module):

+    def __init__(self):
+        super().__init__()
+        self.topk_topp_sampler = TopKTopPSampler()
+
     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        _apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
-                                   sampling_metadata.stop_token_ids,
-                                   sampling_metadata.min_tokens)
-        if not sampling_metadata.no_penalties:
-            assert sampling_metadata.prompt_token_ids is not None
-            _apply_penalties(logits, sampling_metadata.prompt_token_ids,
-                             sampling_metadata.presence_penalties,
-                             sampling_metadata.frequency_penalties,
-                             sampling_metadata.repetition_penalties,
-                             sampling_metadata.output_token_ids)
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
-        logits = self.apply_top_k_top_p(logits, sampling_metadata)
-        probs = self.get_probs(logits)
-        sampled = self.sample(probs, sampling_metadata)
-        # Use int32 to reduce the tensor size.
-        sampled = sampled.to(torch.int32)
-
-        if sampling_metadata.max_num_logprobs > 0:
-            logprobs = self.get_logprobs(logits)
-            # FIXME: Mask the sampled token_id, get topk logprobs,
-            # and concatenate the topk with the sampled token_id.
-            topk_logprobs, topk_indices = torch.topk(
-                logprobs, sampling_metadata.max_num_logprobs, dim=-1)
-            # Use int32 to reduce the tensor size.
-            topk_indices = topk_indices.to(torch.int32)
+        needs_logprobs = sampling_metadata.max_num_logprobs > 0
+        if needs_logprobs:
+            # NOTE(woosuk): Use the original logits (before any penalties or
+            # temperature scaling) for the top-k logprobs.
+            # This is different from the V0 sampler, which uses the logits that
+            # is used for sampling (after penalties and temperature scaling).
+            # NOTE: We compute logprobs first because the below ops may
+            # modify the logits tensor in-place (and we don't want to clone
+            # the logits tensor for memory efficiency).
+            topk_logprobs, topk_indices = self.get_topk_logprobs(
+                logits, sampling_metadata)
         else:
             topk_logprobs = None
             topk_indices = None

+        # Use float32 for the logits.
+        logits = logits.to(torch.float32)
+        # Apply penalties (e.g., min_tokens, freq_penalties).
+        logits = self.apply_penalties(logits, sampling_metadata)
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        # Sample the next token.
+        sampled = self.sample(logits, sampling_metadata)
+        # Use int32 to reduce the tensor size.
+        sampled = sampled.to(torch.int32)
+
         # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
             sampled_token_ids=sampled.tolist(),
@@ -63,71 +65,37 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Use float32 to apply temperature scaling.
-        logits = logits.to(torch.float32)
         # Avoid division by zero.
         temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         # Use in-place division to avoid creating a new tensor.
         logits.div_(temp.unsqueeze(dim=1))
         return logits

-    def apply_top_k_top_p(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
-        return _apply_top_k_top_p(
-            logits,
-            sampling_metadata.no_top_k,
-            sampling_metadata.top_k,
-            sampling_metadata.no_top_p,
-            sampling_metadata.top_p,
-        )
-
-    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.softmax(logits, dim=-1, dtype=torch.float32)
-
-    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
-
-    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
-        return probs.argmax(dim=-1).view(-1)
-
-    def random_sample(
-        self,
-        probs: torch.Tensor,
-        generators: Dict[int, torch.Generator],
-    ) -> torch.Tensor:
-        q = torch.empty_like(probs)
-        # NOTE(woosuk): To batch-process the requests without their own seeds,
-        # which is the common case, we first assume that every request does
-        # not have its own seed. Then, we overwrite the values for the requests
-        # that have their own seeds.
-        if len(generators) != probs.shape[0]:
-            # This might still be done here unnecessarily if there are greedies
-            q.exponential_()
-        if generators:
-            # TODO(woosuk): This can be slow because we handle each request
-            # one by one. Optimize this.
-            for i, generator in generators.items():
-                q[i].exponential_(generator=generator)
-        return probs.div_(q).argmax(dim=-1).view(-1)
+    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+        return logits.argmax(dim=-1).view(-1)

     def sample(
         self,
-        probs: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
         if sampling_metadata.all_greedy:
-            return self.greedy_sample(probs)
-        if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators)
+            return self.greedy_sample(logits)

-        greedy_sampled = self.greedy_sample(probs)
-        random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators)
+        random_sampled = self.topk_topp_sampler(
+            logits,
+            sampling_metadata.generators,
+            sampling_metadata.no_top_k,
+            sampling_metadata.top_k,
+            sampling_metadata.no_top_p,
+            sampling_metadata.top_p,
+        )
+        if sampling_metadata.all_random:
+            return random_sampled
+
+        greedy_sampled = self.greedy_sample(logits)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
@@ -135,86 +103,34 @@ class Sampler(nn.Module):
         )
         return sampled

+    def get_topk_logprobs(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
+        # FIXME: Mask the sampled token_id, get topk logprobs,
+        # and concatenate the topk with the sampled token_id.
+        topk_logprobs, topk_indices = torch.topk(
+            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
+        # Use int32 to reduce the tensor size.
+        topk_indices = topk_indices.to(torch.int32)
+        return topk_logprobs, topk_indices

-# TODO(woosuk): Optimize this with a custom kernel.
-def _apply_top_k_top_p(
-    logits: torch.Tensor,
-    no_top_k: bool,
-    k: torch.Tensor,
-    no_top_p: bool,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    if no_top_k and no_top_p:
+    def apply_penalties(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
+                                  sampling_metadata.stop_token_ids,
+                                  sampling_metadata.min_tokens)
+        if not sampling_metadata.no_penalties:
+            assert sampling_metadata.prompt_token_ids is not None
+            logits = apply_all_penalties(
+                logits, sampling_metadata.prompt_token_ids,
+                sampling_metadata.presence_penalties,
+                sampling_metadata.frequency_penalties,
+                sampling_metadata.repetition_penalties,
+                sampling_metadata.output_token_ids)
         return logits
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-
-    if not no_top_k:
-        # Apply top-k.
-        top_k_mask = logits_sort.size(1) - k.to(torch.long)
-        # Get all the top_k values.
-        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
-        top_k_mask = logits_sort < top_k_mask
-        logits_sort.masked_fill_(top_k_mask, -float("inf"))
-
-    if not no_top_p:
-        # Apply top-p.
-        probs_sort = logits_sort.softmax(dim=-1)
-        probs_sum = probs_sort.cumsum(dim=-1)
-        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
-        # at least one
-        top_p_mask[:, -1] = False
-        logits_sort.masked_fill_(top_p_mask, -float("inf"))
-
-    # Re-sort the probabilities.
-    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
-    return logits
-
-
-def _apply_min_token_penalties(logits: torch.Tensor,
-                               output_token_ids: List[List[int]],
-                               stop_token_ids: List[Set[int]],
-                               min_tokens: List[int]):
-    """
-    Applies minimum token penalty by setting the logits of the stop tokens
-    to -inf.
-    """
-    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
-    for index, min_token in enumerate(min_tokens):
-        if (len(output_token_ids[index]) < min_token):
-            for stop_token_id in stop_token_ids[index]:
-                min_tokens_logits_to_penalize.append((index, stop_token_id))
-    if min_tokens_logits_to_penalize:
-        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
-
-
-def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
-                     presence_penalties: torch.Tensor,
-                     frequency_penalties: torch.Tensor,
-                     repetition_penalties: torch.Tensor,
-                     output_token_ids: List[List[int]]):
-    """
-    Applies presence, frequency and repetition penalties to the logits.
-    """
-    _, vocab_size = logits.shape
-    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
-                                          logits.device)
-    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
-                           presence_penalties, frequency_penalties,
-                           repetition_penalties)
-
-
-def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
-                        device: torch.device) -> torch.Tensor:
-    """
-    Convert the different list data structures to tensors.
-    """
-    output_tokens_tensor = make_tensor_with_pad(
-        output_token_ids,
-        # Use the value of vocab_size as a pad since we don't have a
-        # token_id of this value.
-        pad=vocab_size,
-        device="cpu",
-        dtype=torch.int64,
-        pin_memory=is_pin_memory_available(),
-    )
-    return output_tokens_tensor.to(device, non_blocking=True)