[Bugfix][Model] Fix Qwen3.5/Qwen3Next ignoring --dtype flag on older GPUs (#35617)
This commit is contained in:
@@ -274,7 +274,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
|
|||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
dtype=config.dtype,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.ffn_layer_scale = torch.nn.Parameter(
|
self.ffn_layer_scale = torch.nn.Parameter(
|
||||||
@@ -282,7 +281,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
|
|||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
dtype=config.dtype,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -463,7 +463,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
group_size=None,
|
group_size=None,
|
||||||
norm_before_gate=True,
|
norm_before_gate=True,
|
||||||
device=current_platform.current_device(),
|
device=current_platform.current_device(),
|
||||||
dtype=config.dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.out_proj = RowParallelLinear(
|
self.out_proj = RowParallelLinear(
|
||||||
@@ -1018,7 +1017,6 @@ class Qwen3NextDecoderLayer(nn.Module):
|
|||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
dtype=config.dtype,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.ffn_layer_scale = torch.nn.Parameter(
|
self.ffn_layer_scale = torch.nn.Parameter(
|
||||||
@@ -1026,7 +1024,6 @@ class Qwen3NextDecoderLayer(nn.Module):
|
|||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
dtype=config.dtype,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user