[fix]: remove data type hardcoding from gptoss model implementation (#23807)
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
@@ -76,7 +76,6 @@ class OAIAttention(nn.Module):
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
-                        dtype=torch.bfloat16,
                         requires_grad=False))
 
         self.q_size = self.num_attention_heads * self.head_dim // tp_size
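With the hardcoded dtype removed, torch.empty allocates the sinks tensor in whatever default dtype is active when the module is constructed, so the parameter follows the model's configured dtype instead of always being bfloat16. A minimal standalone sketch of that behavior (arbitrary head count, not vLLM loader code):

import torch

# torch.empty with no explicit dtype uses the process-wide default
# dtype, so the sinks parameter now matches the surrounding model.
torch.set_default_dtype(torch.float16)
sinks = torch.nn.Parameter(torch.empty(8, requires_grad=False))
assert sinks.dtype == torch.float16

torch.set_default_dtype(torch.bfloat16)
sinks = torch.nn.Parameter(torch.empty(8, requires_grad=False))
assert sinks.dtype == torch.bfloat16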
@@ -145,8 +144,7 @@ class MLPBlock(torch.nn.Module):
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
         self.router = torch.nn.Linear(config.hidden_size,
-                                      config.num_local_experts,
-                                      dtype=torch.bfloat16)
+                                      config.num_local_experts)
         assert config.intermediate_size % self.world_size == 0
         self.experts = FusedMoE(num_experts=config.num_local_experts,
                                 top_k=config.num_experts_per_tok,
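torch.nn.Linear behaves the same way: without an explicit dtype argument its weight and bias are created in the current default dtype, so the router now tracks the dtype selected for the model rather than being pinned to bfloat16. A small illustrative check (hypothetical layer sizes):

import torch

# nn.Linear parameters inherit the active default dtype when no dtype
# argument is passed, mirroring the router after this change.
torch.set_default_dtype(torch.bfloat16)
router = torch.nn.Linear(16, 4)
assert router.weight.dtype == torch.bfloat16
assert router.bias.dtype == torch.bfloat16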