[Quantization][Deprecation] Remove DeepSpeedFp8 (#32679)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2026-01-21 09:32:12 -05:00
committed by GitHub
parent 42135d6898
commit cea3c754c4
5 changed files with 19 additions and 284 deletions

View File

@@ -29,10 +29,6 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig,
DeepSpeedFPParameter,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
@@ -128,7 +124,6 @@ class ArcticMoE(nn.Module):
self.intermediate_size = config.intermediate_size // self.tp_size
self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
self.reduce_results = reduce_results
# Some other parameters
if params_dtype is None:
@@ -151,40 +146,24 @@ class ArcticMoE(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.gate",
)
if self.is_quant:
self.ws = DeepSpeedFPParameter(
torch.Size(
(self.num_experts, 2 * self.intermediate_size, self.hidden_size)
),
params_dtype=params_dtype,
quant_config=quant_config,
self.ws = nn.Parameter(
torch.empty(
self.num_experts,
2 * self.intermediate_size,
self.hidden_size,
device=current_platform.device_type,
dtype=self.params_dtype,
)
self.w2s = DeepSpeedFPParameter(
torch.Size(
(self.num_experts, self.hidden_size, self.intermediate_size)
),
params_dtype=params_dtype,
quant_config=quant_config,
)
else:
self.ws = nn.Parameter(
torch.empty(
self.num_experts,
2 * self.intermediate_size,
self.hidden_size,
device=current_platform.device_type,
dtype=self.params_dtype,
)
)
self.w2s = nn.Parameter(
torch.empty(
self.num_experts,
self.hidden_size,
self.intermediate_size,
device=current_platform.device_type,
dtype=self.params_dtype,
)
)
self.w2s = nn.Parameter(
torch.empty(
self.num_experts,
self.hidden_size,
self.intermediate_size,
device=current_platform.device_type,
dtype=self.params_dtype,
)
)
set_weight_attrs(
self.ws,
{
@@ -206,7 +185,7 @@ class ArcticMoE(nn.Module):
expert_id: int,
):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.ds_dequantize() if self.is_quant else param.data
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
@@ -217,8 +196,6 @@ class ArcticMoE(nn.Module):
]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if self.is_quant:
param.ds_quantize_(param_data)
def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
@@ -229,26 +206,10 @@ class ArcticMoE(nn.Module):
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, router_logits, self.top_k, renormalize=do_normalize
)
# topk_ids: (num_tokens, k)
if self.is_quant:
if 2 * num_tokens <= self.num_experts:
# If much fewer tokens than experts, use selective dequantize.
ws_dequantized = self.ws.ds_selective_dequantize(topk_ids.flatten())
w2s_dequantized = self.w2s.ds_selective_dequantize(topk_ids.flatten())
# We gathered the experts to the tokens so update the mapping.
topk_ids = torch.arange(
0,
topk_ids.numel(),
device=topk_ids.device,
).reshape(topk_ids.shape)
else:
ws_dequantized = self.ws.ds_dequantize()
w2s_dequantized = self.w2s.ds_dequantize()
final_hidden_states = fused_experts(
hidden_states,
ws_dequantized if self.is_quant else self.ws,
w2s_dequantized if self.is_quant else self.w2s,
self.ws,
self.w2s,
topk_weights,
topk_ids,
inplace=True,