[Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) (#4527)
Follow on to #4332 to enable FP8 checkpoint loading for Mixtral and supersedes #4436. This PR enables the following checkpoint loading features for Mixtral: Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model Supports static or dynamic activation quantization with static weight quantization (all per tensor) Supports different scales for each expert weight Supports Fp8 in QKV layer Notes: The Expert Gate/Router always runs at half / full precision for now. If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
This commit is contained in:
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
|
||||
for i in range(config.num_local_experts):
|
||||
weights = (hf_moe.experts[i].w1.weight.data,
|
||||
hf_moe.experts[i].w3.weight.data)
|
||||
vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
|
||||
vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
|
||||
vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
|
||||
vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
|
||||
|
||||
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
|
||||
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
|
||||
|
||||
Reference in New Issue
Block a user