[core] remove beam search from the core (#9105)

This commit is contained in:
youkaichao
2024-10-06 22:47:04 -07:00
committed by GitHub
parent c8f26bb636
commit 18b296fdb2
25 changed files with 98 additions and 596 deletions

View File

@@ -23,7 +23,6 @@ class RequestFuncInput:
output_len: int
model: str
best_of: int = 1
use_beam_search: bool = False
logprobs: Optional[int] = None
multi_modal_content: Optional[dict] = None
ignore_eos: bool = False
@@ -49,7 +48,6 @@ async def async_request_tgi(
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
params = {
"best_of": request_func_input.best_of,
"max_new_tokens": request_func_input.output_len,
@@ -121,7 +119,6 @@ async def async_request_trt_llm(
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
assert request_func_input.best_of == 1
payload = {
"accumulate_tokens": True,
@@ -187,7 +184,6 @@ async def async_request_deepspeed_mii(
) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1
assert not request_func_input.use_beam_search
payload = {
"prompt": request_func_input.prompt,
@@ -235,7 +231,6 @@ async def async_request_openai_completions(
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
@@ -317,7 +312,6 @@ async def async_request_openai_chat_completions(
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)