[Misc] Remove Gemma RoPE (#7638)

This commit is contained in:
Woosuk Kwon
2024-08-19 09:29:31 -07:00
committed by GitHub
parent 1a36287b89
commit df845b2b46
3 changed files with 7 additions and 26 deletions

View File

@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -148,14 +148,12 @@ class GemmaAttention(nn.Module):
quant_config=quant_config,
)
-# TODO(woosuk): Use the `get_rope` interface.
-self.rotary_emb = GemmaRotaryEmbedding(
+self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
-max_position_embeddings=max_position_embeddings,
+max_position=max_position_embeddings,
base=self.rope_theta,
is_neox_style=True,
-dtype=torch.get_default_dtype(),
)
self.attn = Attention(self.num_heads,
self.head_dim,