[ Misc ] Refactor MoE to isolate Fp8 From Mixtral (#5970)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Robert Shaw
2024-07-02 17:54:35 -04:00
committed by GitHub
parent 4d26d806e1
commit 7c008c51a9
10 changed files with 535 additions and 304 deletions


@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
-        vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
+        vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
+        vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")