[Model] Remove hardcoded image token ids from Pixtral (#11582)
Signed-off-by: Roger Wang <ywang@roblox.com>
@@ -45,13 +45,6 @@ try:
 except ImportError:
     USE_XFORMERS_OPS = False
 
-# These token ids cannot be retrieved from model config
-# so we hardcode them here.
-PIXTRAL_12B_IMAGE_BREAK_ID = 12
-PIXTRAL_12B_IMAGE_END_ID = 13
-PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
-PIXTRAL_LARGE_IMAGE_END_ID = 15
-
 
 def get_max_pixtral_image_tokens(ctx: InputContext):
     tokenizer = cached_get_tokenizer(
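
Context for reviewers: the removed constants duplicated values that each checkpoint already ships in its params.json under the vision_encoder section (the file and key names are taken from the error message added in the next hunk). A minimal, illustrative sketch of reading the ids from there instead — not code from this commit:

import json

# Illustrative only: load the checkpoint's params.json and pull the ids
# that were previously hardcoded.
with open("params.json") as f:
    params = json.load(f)

vision_args = params["vision_encoder"]
image_break_token_id = vision_args["image_break_token_id"]  # 12 on Pixtral-12B
image_end_token_id = vision_args["image_end_token_id"]      # 13 on Pixtral-12B
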
@@ -201,6 +194,13 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             if key in dataclass_fields
         }
 
+        if not ("image_break_token_id" in vision_args
+                and "image_end_token_id" in vision_args):
+            raise ValueError(
+                "'image_break_token_id' and 'image_end_token_id' not found "
+                "in the vision_encoder arguments. Please download the latest "
+                "version of 'params.json' from the model repository.")
+
         self.vision_args = VisionEncoderArgs(**vision_args)
 
         # init MistralForCausalLM
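
The guard fails fast on stale configs rather than silently using wrong token ids. A toy illustration of the condition (values made up):

vision_args = {"image_token_id": 10}  # stale config: break/end ids missing
ok = ("image_break_token_id" in vision_args
      and "image_end_token_id" in vision_args)
print(ok)  # False -> the constructor raises ValueError
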
@@ -240,9 +240,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         # NOTE: Image embeddings are split into separate tensors for each image
         # by the indices of `[IMG_END]` token.
-        image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
-            image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
-        split_indices = torch.where(image_end_condition)[0] + 1
+        image_end_mask = image_tokens == self.vision_args.image_end_token_id
+        split_indices = torch.where(image_end_mask)[0] + 1
         if len(split_indices) <= 1:
             # Do not split, return as tensor of shape [1, fs, hs]
             return image_embeds.unsqueeze(0)
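
A standalone sketch of the splitting step, assuming a flat stream of image token ids where 13 is the [IMG_END] id read from the loaded config (all tensors here are fabricated for illustration):

import torch

image_tokens = torch.tensor([10, 10, 13, 10, 10, 10, 13])  # two images
image_end_mask = image_tokens == 13
split_indices = torch.where(image_end_mask)[0] + 1          # tensor([3, 7])

# With a single image there is only one end index, hence the
# `len(split_indices) <= 1` early return above.
embeds = torch.randn(7, 4)  # one embedding per image token
# The last boundary equals the sequence length, so drop it before splitting.
chunks = torch.tensor_split(embeds, split_indices[:-1].tolist(), dim=0)
print([c.shape for c in chunks])  # [torch.Size([3, 4]), torch.Size([4, 4])]
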
@@ -265,10 +264,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids, inputs_embeds, multimodal_embeddings, [
|
input_ids, inputs_embeds, multimodal_embeddings, [
|
||||||
self.vision_args.image_token_id,
|
self.vision_args.image_token_id,
|
||||||
PIXTRAL_12B_IMAGE_END_ID,
|
self.vision_args.image_break_token_id,
|
||||||
PIXTRAL_12B_IMAGE_BREAK_ID,
|
self.vision_args.image_end_token_id,
|
||||||
PIXTRAL_LARGE_IMAGE_BREAK_ID,
|
|
||||||
PIXTRAL_LARGE_IMAGE_END_ID,
|
|
||||||
])
|
])
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
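
Conceptually, the merge scatters vision embeddings into the text embedding sequence at every position holding one of the listed image token ids. A self-contained sketch of that behavior (ids, shapes, and names are illustrative, not the real API):

import torch

IMG, BREAK, END = 10, 12, 13
input_ids = torch.tensor([1, IMG, IMG, BREAK, IMG, END, 2])
inputs_embeds = torch.zeros(7, 4)   # text embeddings, one row per token
mm_embeds = torch.randn(5, 4)       # one row per image-related position

# Replace embeddings wherever the token is an image placeholder.
is_image = torch.isin(input_ids, torch.tensor([IMG, BREAK, END]))
inputs_embeds[is_image] = mm_embeds
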
@@ -409,6 +406,8 @@ class VisionEncoderArgs:
     num_attention_heads: int
     rope_theta: float  # for rope-2D
     image_token_id: int
+    image_break_token_id: int
+    image_end_token_id: int
     adapter_bias: bool = True
 
 
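
The two new fields are populated by the same key filter seen in the constructor hunk (`if key in dataclass_fields`). A trimmed, hypothetical stand-in showing how that filter feeds the dataclass:

from dataclasses import dataclass, fields

@dataclass
class VisionEncoderArgsSketch:  # stand-in, not the real class
    image_token_id: int
    image_break_token_id: int
    image_end_token_id: int

raw = {"image_token_id": 10, "image_break_token_id": 12,
       "image_end_token_id": 13, "unrecognized_key": 1}
names = {f.name for f in fields(VisionEncoderArgsSketch)}
# Unknown keys are dropped; the known ones construct the dataclass.
args = VisionEncoderArgsSketch(**{k: v for k, v in raw.items() if k in names})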