[Quantization][Deprecation] Remove DeepSpeedFp8 (#32679)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -29,10 +29,6 @@ from vllm.model_executor.layers.linear import (
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.deepspeedfp import (
|
||||
DeepSpeedFPConfig,
|
||||
DeepSpeedFPParameter,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
@@ -128,7 +124,6 @@ class ArcticMoE(nn.Module):
|
||||
self.intermediate_size = config.intermediate_size // self.tp_size
|
||||
|
||||
self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
|
||||
self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
|
||||
self.reduce_results = reduce_results
|
||||
# Some other parameters
|
||||
if params_dtype is None:
|
||||
@@ -151,40 +146,24 @@ class ArcticMoE(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
if self.is_quant:
|
||||
self.ws = DeepSpeedFPParameter(
|
||||
torch.Size(
|
||||
(self.num_experts, 2 * self.intermediate_size, self.hidden_size)
|
||||
),
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
self.ws = nn.Parameter(
|
||||
torch.empty(
|
||||
self.num_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
device=current_platform.device_type,
|
||||
dtype=self.params_dtype,
|
||||
)
|
||||
self.w2s = DeepSpeedFPParameter(
|
||||
torch.Size(
|
||||
(self.num_experts, self.hidden_size, self.intermediate_size)
|
||||
),
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.ws = nn.Parameter(
|
||||
torch.empty(
|
||||
self.num_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
device=current_platform.device_type,
|
||||
dtype=self.params_dtype,
|
||||
)
|
||||
)
|
||||
self.w2s = nn.Parameter(
|
||||
torch.empty(
|
||||
self.num_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
device=current_platform.device_type,
|
||||
dtype=self.params_dtype,
|
||||
)
|
||||
)
|
||||
self.w2s = nn.Parameter(
|
||||
torch.empty(
|
||||
self.num_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
device=current_platform.device_type,
|
||||
dtype=self.params_dtype,
|
||||
)
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.ws,
|
||||
{
|
||||
@@ -206,7 +185,7 @@ class ArcticMoE(nn.Module):
|
||||
expert_id: int,
|
||||
):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.ds_dequantize() if self.is_quant else param.data
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
if weight_name.endswith("w1.weight"):
|
||||
@@ -217,8 +196,6 @@ class ArcticMoE(nn.Module):
|
||||
]
|
||||
if weight_name.endswith("w2.weight"):
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
if self.is_quant:
|
||||
param.ds_quantize_(param_data)
|
||||
|
||||
def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
@@ -229,26 +206,10 @@ class ArcticMoE(nn.Module):
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, router_logits, self.top_k, renormalize=do_normalize
|
||||
)
|
||||
# topk_ids: (num_tokens, k)
|
||||
if self.is_quant:
|
||||
if 2 * num_tokens <= self.num_experts:
|
||||
# If much fewer tokens than experts, use selective dequantize.
|
||||
ws_dequantized = self.ws.ds_selective_dequantize(topk_ids.flatten())
|
||||
w2s_dequantized = self.w2s.ds_selective_dequantize(topk_ids.flatten())
|
||||
# We gathered the experts to the tokens so update the mapping.
|
||||
topk_ids = torch.arange(
|
||||
0,
|
||||
topk_ids.numel(),
|
||||
device=topk_ids.device,
|
||||
).reshape(topk_ids.shape)
|
||||
else:
|
||||
ws_dequantized = self.ws.ds_dequantize()
|
||||
w2s_dequantized = self.w2s.ds_dequantize()
|
||||
|
||||
final_hidden_states = fused_experts(
|
||||
hidden_states,
|
||||
ws_dequantized if self.is_quant else self.ws,
|
||||
w2s_dequantized if self.is_quant else self.w2s,
|
||||
self.ws,
|
||||
self.w2s,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True,
|
||||
|
||||
Reference in New Issue
Block a user