[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -27,13 +27,14 @@ from vllm.distributed.parallel_state import (
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from ...utils import TestFP8Layer, multi_gpu_test
|
||||
from ..backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -94,50 +95,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size)
|
||||
.to(dtype=current_platform.fp8_dtype())
|
||||
.t()
|
||||
for _ in range(3)
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
z2 = self.fp8_linear_layers[0](y)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
z3 = self.fp8_linear_layers[1](y2)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
z4 = self.fp8_linear_layers[2](y3)
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
@@ -160,7 +151,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
return [
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
]
|
||||
elif self.fp8_linear.quant_fp8.enabled():
|
||||
elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers):
|
||||
return [
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user