[BugFix] 1D query fix for MoE models (#3597)
This commit is contained in:
@@ -81,11 +81,13 @@ def test_mixtral_moe(dtype: torch.dtype):
         vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
-    inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+    # vLLM uses 1D query [num_tokens, hidden_dim]
+    vllm_inputs = hf_inputs.flatten(0, 1)
 
     # Run forward passes for both MoE blocks
-    hf_states, _ = hf_moe.forward(inputs)
-    vllm_states = vllm_moe.forward(inputs)
+    hf_states, _ = hf_moe.forward(hf_inputs)
+    vllm_states = vllm_moe.forward(vllm_inputs)
 
     mixtral_moe_tol = {
         torch.float32: 1e-3,
|
||||
@@ -93,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         torch.bfloat16: 1e-2,
     }
 
-    assert torch.allclose(hf_states,
+    assert torch.allclose(hf_states.flatten(0, 1),
                           vllm_states,
                           rtol=mixtral_moe_tol[dtype],
                           atol=mixtral_moe_tol[dtype])
|
||||
|
||||
Reference in New Issue
Block a user