Fix test_moe.py for Transformers v5 (#33413)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -680,13 +680,21 @@ def test_mixtral_moe(
 
     # Load the weights
     vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
-    for i in range(config.num_local_experts):
-        weights = (
-            hf_moe.experts[i].w1.weight.data,
-            hf_moe.experts[i].w3.weight.data,
-        )
-        vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
+    if isinstance(hf_moe.experts, torch.nn.ModuleList):
+        # Transformers v4
+        for i in range(config.num_local_experts):
+            weights = (
+                hf_moe.experts[i].w1.weight.data,
+                hf_moe.experts[i].w3.weight.data,
+            )
+            vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
+            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
+    else:
+        # Transformers v5
+        vllm_moe.experts.w13_weight.data[:] = hf_moe.experts.gate_up_proj.data
+        vllm_moe.experts.w2_weight.data[:] = hf_moe.experts.down_proj.data
+        # TODO: remove this line after https://github.com/huggingface/transformers/pull/43622
+        hf_moe.experts.config._experts_implementation = "eager"
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
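For readers comparing the two code paths: in Transformers v4 the Mixtral block exposes its experts as a `torch.nn.ModuleList` of per-expert MLPs with separate `w1`/`w3`/`w2` projections, whereas the v5 branch above reads fused `gate_up_proj` and `down_proj` tensors that stack all experts in single parameters. A minimal standalone sketch of that version dispatch follows; the `load_expert_weights` helper name and its argument list are illustrative, not part of the test.

```python
import torch


def load_expert_weights(vllm_moe, hf_moe, num_experts: int) -> None:
    """Copy HF Mixtral expert weights into vLLM's fused MoE buffers.

    Sketch of the dispatch used in the hunk above; ``vllm_moe`` and
    ``hf_moe`` are assumed to be already-constructed MoE blocks.
    """
    if isinstance(hf_moe.experts, torch.nn.ModuleList):
        # Transformers v4: one MLP per expert, with separate gate (w1),
        # up (w3) and down (w2) projections.
        for i in range(num_experts):
            gate_up = torch.cat(
                (
                    hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data,
                ),
                dim=0,
            )
            vllm_moe.experts.w13_weight[i][:] = gate_up
            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
    else:
        # Transformers v5: experts are stored as fused tensors spanning all
        # experts, so a single bulk copy per projection is enough.
        vllm_moe.experts.w13_weight.data[:] = hf_moe.experts.gate_up_proj.data
        vllm_moe.experts.w2_weight.data[:] = hf_moe.experts.down_proj.data
```

Dispatching on `isinstance(hf_moe.experts, torch.nn.ModuleList)` lets the same test run against either Transformers major version without pinning one.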
@@ -718,7 +726,10 @@ def test_mixtral_moe(
         get_forward_context().all_moe_layers = None
 
     # Run forward passes for both MoE blocks
-    hf_states, _ = hf_moe.forward(hf_inputs)
+    hf_states = hf_moe.forward(hf_inputs)
+    if isinstance(hf_states, tuple):
+        # Transformers v4
+        hf_states = hf_states[0]
     vllm_states = vllm_moe.forward(vllm_inputs)
 
     mixtral_moe_tol = {
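The second hunk handles a matching change in the block's return type: the old `hf_states, _ = hf_moe.forward(hf_inputs)` unpacking shows that v4 returns a two-element tuple (the discarded element being the router logits), while the v5 path returns the hidden states directly. Below is a minimal sketch of that normalization, assuming only the tuple-vs-tensor distinction; the `run_hf_forward` helper name is mine, not the test's.

```python
import torch


def run_hf_forward(hf_moe, hf_inputs: torch.Tensor) -> torch.Tensor:
    """Run the HF MoE block and normalize its output across versions."""
    out = hf_moe.forward(hf_inputs)
    if isinstance(out, tuple):
        # Transformers v4 returns (hidden_states, router_logits);
        # keep only the hidden states for comparison against vLLM.
        out = out[0]
    return out
```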