[Misc] Remove Gemma RoPE (#7638)
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -148,14 +148,12 @@ class GemmaAttention(nn.Module):
             quant_config=quant_config,
         )
 
-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position_embeddings=max_position_embeddings,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
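For context: the commit drops the Gemma-specific GemmaRotaryEmbedding class in favor of the shared get_rope factory, renaming the max_position_embeddings keyword to max_position and no longer passing an explicit dtype. A minimal sketch of the resulting construction is below; it only assumes the get_rope call shape shown in the diff, and the concrete numeric values are illustrative, not taken from the commit.

# Sketch (not part of the PR): building Gemma's rotary embedding with the
# shared get_rope helper, mirroring the call in the diff above.
from vllm.model_executor.layers.rotary_embedding import get_rope

head_dim = 256                   # illustrative per-head dimension
rope_theta = 10000.0             # illustrative RoPE base frequency
max_position_embeddings = 8192   # illustrative maximum sequence length

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,         # rotate the full head dimension
    max_position=max_position_embeddings,
    base=rope_theta,
    is_neox_style=True,          # as in the diff above
)

# The returned module is applied to query/key tensors given token positions;
# inside the attention forward pass this looks roughly like:
#     q, k = rotary_emb(positions, q, k)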