[Bugfix] Use runner_type instead of task in GritLM (#11144)

Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
Pooya Davoodi
2024-12-12 20:09:53 -08:00
committed by GitHub
parent 30870b4f66
commit 1efce68605
2 changed files with 7 additions and 7 deletions

View File

@@ -203,12 +203,12 @@ class GritLM(LlamaForCausalLM):
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.task = vllm_config.model_config.task
self.runner_type = vllm_config.model_config.runner_type
self._pooler = GritLMPooler(vllm_config.model_config)
for layer in self.model.layers:
if self.task == "embedding" and hasattr(layer, "self_attn"):
if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
"GritLM embedding is only supported by XFormers backend, "
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
@@ -222,8 +222,8 @@ class GritLM(LlamaForCausalLM):
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
# Change attention to non-causal for embedding task.
if self.task == "embedding":
# Change attention to non-causal for pooling tasks.
if self.runner_type == "pooling":
assert attn_metadata.prefill_metadata.attn_bias is None
attn_metadata.prefill_metadata.attn_bias = [
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)