[Core] Add LoRA Support to Beam Search (#18346)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Logprob
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -19,6 +20,7 @@ class BeamSearchSequence:
|
||||
# The tokens include the prompt.
|
||||
tokens: list[int]
|
||||
logprobs: list[dict[int, Logprob]]
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
cum_logprob: float = 0.0
|
||||
text: Optional[str] = None
|
||||
finish_reason: Optional[str] = None
|
||||
@@ -41,6 +43,7 @@ class BeamSearchInstance:
|
||||
def __init__(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
logprobs: Optional[list[dict[int, Logprob]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -48,6 +51,7 @@ class BeamSearchInstance:
|
||||
BeamSearchSequence(
|
||||
tokens=prompt_tokens,
|
||||
logprobs=[] if logprobs is None else list(logprobs),
|
||||
lora_request=lora_request,
|
||||
**kwargs,
|
||||
)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user