[Core] Rename input data types (#8688)
@@ -14,7 +14,7 @@ from torch.nn import LayerNorm
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
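
The only API-visible change in the hunk above is the rename of LLMInputs to DecoderOnlyInputs. For out-of-tree model code that has to run on vLLM releases from both sides of this commit, a minimal compatibility shim (an assumption for illustration, not part of this change) could look like:

try:
    # Post-rename name for the decoder-only input dict
    from vllm.inputs import DecoderOnlyInputs
except ImportError:
    # Fall back to the old name on vLLM releases that predate this commit
    from vllm.inputs import LLMInputs as DecoderOnlyInputs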
@@ -149,20 +149,20 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
     return [index for index, value in enumerate(input_ids) if value == target]


-def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
+def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
     hf_config = ctx.get_hf_config(ChatGLMConfig)
     vision_config = getattr(hf_config, 'vision_config', None)

     if vision_config is None:
-        return llm_inputs
+        return inputs
     elif isinstance(vision_config, dict):
         image_placeholder_length = calculate_image_placeholder(vision_config)
     else:
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)

-    input_ids = llm_inputs.get("prompt_token_ids")
-    position_ids = llm_inputs.get("position_ids")
+    input_ids = inputs.get("prompt_token_ids")
+    position_ids = inputs.get("position_ids")
     tokenizer = cached_get_tokenizer(
         ctx.model_config.model,
         trust_remote_code=ctx.model_config.trust_remote_code)
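
For context on how the function above is invoked: vLLM runs an input processor on every incoming inputs dict before the model sees it, and the processor is attached to the model class through INPUT_REGISTRY. A minimal sketch of that wiring, with a stub processor and class standing in for the real definitions in this file:

import torch.nn as nn

from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext


def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    ...  # the real body is shown in this diff


@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module):
    """Illustrative stub; the real class implements the full GLM-V model."""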
@@ -171,15 +171,15 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
         raw_batch_data = tokenizer.apply_chat_template(
             conversation=[{
                 "role": "user",
-                "image": llm_inputs['multi_modal_data']["image"],
-                "content": llm_inputs['prompt']
+                "image": inputs['multi_modal_data']["image"],
+                "content": inputs['prompt']
             }],
             add_generation_prompt=True,
             tokenize=True,
             return_tensors="pt",
             return_dict=True).data
     except Exception:
-        logger.error("Failed to process content (%s)", llm_inputs['prompt'])
+        logger.error("Failed to process content (%s)", inputs['prompt'])
         raise
     input_ids = raw_batch_data['input_ids'][0].tolist()

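The try block above leans on the GLM tokenizer's custom (trust_remote_code) chat template, which accepts an "image" entry alongside "content" in each message. A standalone sketch of the same call outside vLLM; the model id and image path are assumptions for illustration:

from PIL import Image
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b",
                                          trust_remote_code=True)
raw_batch_data = tokenizer.apply_chat_template(
    conversation=[{
        "role": "user",
        "image": Image.open("example.jpg"),
        "content": "Describe this picture.",
    }],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
).data  # .data unwraps the BatchEncoding into a plain dict of tensors
input_ids = raw_batch_data["input_ids"][0].tolist()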
@@ -214,9 +214,9 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):

     assert len(new_input_ids) == len(new_position_ids)

-    llm_inputs["prompt_token_ids"] = new_input_ids
-    llm_inputs["position_ids"] = new_position_ids
-    return llm_inputs
+    inputs["prompt_token_ids"] = new_input_ids
+    inputs["position_ids"] = new_position_ids
+    return inputs


 class GLMAttention(nn.Module):
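
As a quick illustration of the find_all_positions helper shown in the second hunk (and presumably used by the elided body of input_processor_for_glmv to locate the image placeholder's boundary tokens), it simply returns every index at which a token id occurs:

def find_all_positions(input_ids, target):
    # Same body as in the diff above
    return [index for index, value in enumerate(input_ids) if value == target]


# Token id 9 occurs at indices 1, 2, and 4
assert find_all_positions([5, 9, 9, 2, 9], target=9) == [1, 2, 4]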