[Model] Jamba support (#4115)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: Erez Schwartz <erezs@ai21.com>
Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Tomer Asida <tomera@ai21.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
Mor Zusman
2024-07-03 02:11:29 +03:00
committed by GitHub
parent ee93f4f92a
commit 9d6a8daa87
21 changed files with 1192 additions and 34 deletions

View File

@@ -75,15 +75,19 @@ class TP1DraftModelRunner(ModelRunner):
List[SequenceGroupMetadata]] = None
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
"""
self.cached_seq_group_metadata_list = seq_group_metadata_list
return super().prepare_model_input(seq_group_metadata_list)
return super().prepare_model_input(
seq_group_metadata_list,
finished_requests_ids=finished_requests_ids)
def update_model_input(
self, model_input: ModelInputForGPUWithSamplingMetadata,