[Model] GritLM supports other attention backends (#18109)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -100,19 +100,19 @@ class Qwen2MLP(nn.Module):
|
||||
class Qwen2Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[Tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: Optional[dict[str,
|
||||
Any]] = None) -> None:
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[Tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
Reference in New Issue
Block a user