[Bugfix] Use runner_type instead of task in GritLM (#11144)
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
@@ -203,12 +203,12 @@ class GritLM(LlamaForCausalLM):
|
||||
) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
- self.task = vllm_config.model_config.task
|
||||
+ self.runner_type = vllm_config.model_config.runner_type
|
||||
|
||||
self._pooler = GritLMPooler(vllm_config.model_config)
|
||||
|
||||
for layer in self.model.layers:
|
||||
- if self.task == "embedding" and hasattr(layer, "self_attn"):
|
||||
+ if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
|
||||
assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
|
||||
"GritLM embedding is only supported by XFormers backend, "
|
||||
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
|
||||
@@ -222,8 +222,8 @@ class GritLM(LlamaForCausalLM):
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
|
||||
- # Change attention to non-causal for embedding task.
|
||||
- if self.task == "embedding":
|
||||
+ # Change attention to non-causal for pooling tasks.
|
||||
+ if self.runner_type == "pooling":
|
||||
assert attn_metadata.prefill_metadata.attn_bias is None
|
||||
attn_metadata.prefill_metadata.attn_bias = [
|
||||
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
|
||||
|
||||
Reference in New Issue
Block a user