[Model] Remove hardcoded image tokens ids from Pixtral (#11582)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang
2024-12-28 02:54:23 -08:00
committed by GitHub
parent d34be24bb1
commit b7dcc003dc

View File

@@ -45,13 +45,6 @@ try:
 except ImportError:
     USE_XFORMERS_OPS = False

-# These token ids cannot be retrieved from model config
-# so we hardcode them here.
-PIXTRAL_12B_IMAGE_BREAK_ID = 12
-PIXTRAL_12B_IMAGE_END_ID = 13
-PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
-PIXTRAL_LARGE_IMAGE_END_ID = 15
-

 def get_max_pixtral_image_tokens(ctx: InputContext):
     tokenizer = cached_get_tokenizer(
@@ -201,6 +194,13 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             if key in dataclass_fields
         }

+        if not ("image_break_token_id" in vision_args
+                and "image_end_token_id" in vision_args):
+            raise ValueError(
+                "'image_break_token_id' and 'image_end_token_id' not found "
+                "in the vision_encoder arguments. Please download the latest "
+                "version of 'params.json' from the model repository.")
+
         self.vision_args = VisionEncoderArgs(**vision_args)

         # init MistralForCausalLM
@@ -240,9 +240,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         # NOTE: Image embeddings are split into separate tensors for each image
         # by the indices of `[IMG_END]` token.
-        image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
-            image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
-        split_indices = torch.where(image_end_condition)[0] + 1
+        image_end_mask = image_tokens == self.vision_args.image_end_token_id
+        split_indices = torch.where(image_end_mask)[0] + 1
         if len(split_indices) <= 1:
             # Do not split, return as tensor of shape [1, fs, hs]
             return image_embeds.unsqueeze(0)
@@ -265,10 +264,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings, [
                     self.vision_args.image_token_id,
-                    PIXTRAL_12B_IMAGE_END_ID,
-                    PIXTRAL_12B_IMAGE_BREAK_ID,
-                    PIXTRAL_LARGE_IMAGE_BREAK_ID,
-                    PIXTRAL_LARGE_IMAGE_END_ID,
+                    self.vision_args.image_break_token_id,
+                    self.vision_args.image_end_token_id,
                 ])
         return inputs_embeds
@@ -409,6 +406,8 @@ class VisionEncoderArgs:
     num_attention_heads: int
     rope_theta: float  # for rope-2D
     image_token_id: int
+    image_break_token_id: int
+    image_end_token_id: int
     adapter_bias: bool = True