[Model] Jamba support (#4115)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: Erez Schwartz <erezs@ai21.com>
Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Tomer Asida <tomera@ai21.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
@@ -75,15 +75,19 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
List[SequenceGroupMetadata]] = None
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForGPUWithSamplingMetadata:
|
||||
"""A temporary solution that caches the seq_group_metadata_list
|
||||
for multi-step execution.
|
||||
TODO: In-place update model_input and remove this function.
|
||||
"""
|
||||
self.cached_seq_group_metadata_list = seq_group_metadata_list
|
||||
return super().prepare_model_input(seq_group_metadata_list)
|
||||
return super().prepare_model_input(
|
||||
seq_group_metadata_list,
|
||||
finished_requests_ids=finished_requests_ids)
|
||||
|
||||
def update_model_input(
|
||||
self, model_input: ModelInputForGPUWithSamplingMetadata,
|
||||
|
||||
Reference in New Issue
Block a user