[Bugfix] Enable Kimi k25 processor test (#33562)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -96,16 +96,20 @@ class MoonshotKimiVAutoProcessor(ProcessorMixin):
|
||||
attributes = ["tokenizer"]
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, media_processor=None, tokenizer=None):
|
||||
def __init__(
|
||||
self, media_processor=None, tokenizer=None, media_token_id: int | None = None
|
||||
):
|
||||
super().__init__(tokenizer)
|
||||
self.media_processor = media_processor
|
||||
self.media_token_id = media_token_id
|
||||
assert self.media_token_id is not None
|
||||
|
||||
# We do not support str input for text here
|
||||
def __call__(
|
||||
self,
|
||||
vision_chunks: list[VisionChunk] | None = None,
|
||||
*,
|
||||
text: list[int],
|
||||
text: list[int] | str,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@@ -122,13 +126,30 @@ class MoonshotKimiVAutoProcessor(ProcessorMixin):
|
||||
- **grid_thws** -- list of image 3D grid in LLM. Returned when `vision_chunks` is not `None`.
|
||||
"""
|
||||
mm_inputs = {}
|
||||
input_ids = self.tokenizer.encode(text) if isinstance(text, str) else text
|
||||
if vision_chunks is not None:
|
||||
assert isinstance(vision_chunks, list)
|
||||
mm_inputs = self.media_processor.preprocess(vision_chunks)
|
||||
|
||||
num_tokens_per_chunk = [
|
||||
self.media_processor.media_tokens_calculator(chunk)
|
||||
for chunk in vision_chunks
|
||||
]
|
||||
|
||||
new_input_ids = []
|
||||
for token in input_ids:
|
||||
if token == self.media_token_id:
|
||||
new_input_ids.extend(
|
||||
[self.media_token_id] * num_tokens_per_chunk.pop(0)
|
||||
)
|
||||
else:
|
||||
new_input_ids.append(token)
|
||||
input_ids = new_input_ids
|
||||
|
||||
# XXX: _apply_hf_processor_text_mm will call tolist() on input_ids
|
||||
return BatchFeature(
|
||||
data={
|
||||
"input_ids": torch.tensor([text]),
|
||||
"input_ids": torch.tensor([input_ids]),
|
||||
**mm_inputs,
|
||||
}
|
||||
)
|
||||
@@ -152,6 +173,7 @@ class KimiK25ProcessingInfo(BaseProcessingInfo):
|
||||
self.hf_processor = MoonshotKimiVAutoProcessor(
|
||||
media_processor=self.media_processor,
|
||||
tokenizer=self.get_tokenizer(),
|
||||
media_token_id=self.media_token_id,
|
||||
)
|
||||
self.media_tokens_calculator = self.media_processor.media_tokens_calculator
|
||||
|
||||
@@ -174,9 +196,9 @@ class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]):
|
||||
self.media_token_id = self.info.media_token_id
|
||||
self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_media = mm_counts.get("vision_chunk", 0)
|
||||
return [self.media_token_id] * num_media
|
||||
return "<|media_pad|>" * num_media
|
||||
|
||||
def get_dummy_mm_items(self):
|
||||
dummy_videos = self._get_dummy_images(
|
||||
|
||||
Reference in New Issue
Block a user