[Bug fix][Core] fixup ngram not setup correctly (#4551)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
leiwen83
2024-05-08 02:40:18 +08:00
committed by GitHub
parent 469f85c782
commit 8344f7742b
3 changed files with 29 additions and 13 deletions

View File

@@ -55,7 +55,7 @@ class AsyncLLM:
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
self.engine_args = AsyncEngineArgs(
engine_args = AsyncEngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
@@ -76,6 +76,8 @@ class AsyncLLM:
**kwargs,
)
self.request_counter = Counter()
self.llm_engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
def generate(
self,
@@ -88,9 +90,6 @@ class AsyncLLM:
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
llm_engine = AsyncLLMEngine.from_engine_args(
self.engine_args, usage_context=UsageContext.LLM_CLASS)
if prompts is None:
raise ValueError("prompts must be provided.")
if isinstance(prompts, str):
@@ -111,8 +110,8 @@ class AsyncLLM:
async def get_output(prompt, sampling_param) -> str:
request_id = random_uuid()
results_generator = llm_engine.generate(prompt, sampling_param,
request_id)
results_generator = self.llm_engine.generate(
prompt, sampling_param, request_id)
final_output = None
async for request_output in results_generator:
final_output = request_output
@@ -185,12 +184,25 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
return generator_outer
def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (not isinstance(llm, AsyncLLM)
and llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,
NGramWorker)
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]