[core] simplify seq group code (#9569)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
@@ -647,10 +647,24 @@ class LLMEngine:
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> SequenceGroup:
|
||||
) -> Optional[SequenceGroup]:
|
||||
"""Add a processed request to the engine's request pool.
|
||||
Return the created sequence group, or None for SamplingParams with n > 1.
|
||||
"""
|
||||
if isinstance(params, SamplingParams) and params.n > 1:
|
||||
ParallelSampleSequenceGroup.add_request(
|
||||
request_id,
|
||||
self,
|
||||
params,
|
||||
processed_inputs=processed_inputs,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
)
|
||||
return None
|
||||
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
@@ -721,7 +735,7 @@ class LLMEngine:
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> Optional[SequenceGroup]:
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
@@ -735,7 +749,7 @@ class LLMEngine:
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> Optional[SequenceGroup]:
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@deprecate_kwargs(
|
||||
@@ -754,7 +768,7 @@ class LLMEngine:
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> Optional[SequenceGroup]:
|
||||
) -> None:
|
||||
"""Add a request to the engine's request pool.
|
||||
|
||||
The request is added to the request pool and will be processed by the
|
||||
@@ -798,22 +812,6 @@ class LLMEngine:
|
||||
>>> # continue the request processing
|
||||
>>> ...
|
||||
"""
|
||||
|
||||
if isinstance(params, SamplingParams) and params.n > 1:
|
||||
ParallelSampleSequenceGroup.add_request(
|
||||
request_id,
|
||||
self,
|
||||
params,
|
||||
prompt=prompt,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
inputs=inputs,
|
||||
)
|
||||
return None
|
||||
|
||||
if inputs is not None:
|
||||
prompt = inputs
|
||||
assert prompt is not None and params is not None
|
||||
@@ -844,7 +842,7 @@ class LLMEngine:
|
||||
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
|
||||
"mm_processor_kwargs")
|
||||
|
||||
return self._add_processed_request(
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
processed_inputs=processed_inputs,
|
||||
params=params,
|
||||
|
||||
Reference in New Issue
Block a user