[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
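At a glance: these tests previously constructed an Fp8LinearOp by hand and threaded the FP8 weight, weight scale, and input scale through every call site. They now go through the shared TestFP8Layer helper (imported via "from ..utils import TestFP8Layer"), which owns those tensors and the quant keys, so each call site collapses to a plain module call. Schematically, condensed from the hunks below:

    # before: the test module owns the weight and scales
    out = self.fp8_linear.apply(x, self.w, self.wscale, input_scale=self.scale)
    # after: TestFP8Layer owns them
    out = self.fp8_linear(x)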
@@ -20,11 +20,13 @@ from vllm.config import (
 )
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    kFp8StaticTensorSym,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.platforms import current_platform

+from ..utils import TestFP8Layer
 from .backend import TestBackend

 TEST_FP8 = current_platform.supports_fp8()
@@ -32,24 +34,22 @@ FP8_DTYPE = current_platform.fp8_dtype()


 class TestSiluMul(torch.nn.Module):
+    quant_key = kFp8StaticTensorSym
+
     def __init__(self, hidden_size: int = 128):
         super().__init__()
         self.silu_and_mul = SiluAndMul()
-        self.wscale = torch.rand(1, dtype=torch.float32)
-        self.scale = torch.rand(1, dtype=torch.float32)

         if TEST_FP8:
-            self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
-            self.fp8_linear = Fp8LinearOp(
-                act_quant_static=True,
-                act_quant_group_shape=GroupShape.PER_TENSOR,
+            self.fp8_linear = TestFP8Layer(
+                weight_shape=(hidden_size, hidden_size),
+                activation_quant_key=self.quant_key,
+                weight_quant_key=self.quant_key,
             )

     def forward(self, x):
         y = self.silu_and_mul(x)
         if TEST_FP8:
-            x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
-            return x2
+            return self.fp8_linear(y)
         else:
             return y
@@ -67,6 +67,8 @@ class TestSiluMul(torch.nn.Module):


 class TestFusedAddRMSNorm(torch.nn.Module):
+    quant_key = kFp8StaticTensorSym
+
     def __init__(self, hidden_size=16, intermediate_size=32):
         super().__init__()
         self.hidden_size = hidden_size
@@ -81,11 +83,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
         torch.nn.init.normal_(self.gate_proj, std=0.02)

         if TEST_FP8:
-            self.fp8_linear = Fp8LinearOp(act_quant_static=True)
-
-            self.scale = torch.rand(1, dtype=torch.float32)
-            self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
-            self.wscale = torch.rand(1, dtype=torch.float32)
+            self.fp8_linear = TestFP8Layer(
+                weight_shape=(hidden_size, intermediate_size),
+                activation_quant_key=self.quant_key,
+                weight_quant_key=self.quant_key,
+            )

     def forward(self, hidden_states, residual):
         # Reshape input
@@ -100,12 +102,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):

         if TEST_FP8:
-            # scaled_mm with static input quantization
-            fp8_linear_result = self.fp8_linear.apply(
-                norm_output,
-                self.w,
-                self.wscale,
-                input_scale=self.scale.to(norm_output.device),
-            )
+            fp8_linear_result = self.fp8_linear(norm_output)

         return fp8_linear_result, residual_output
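
The diff shows only TestFP8Layer's call sites, not its body. As a rough sketch of what such a wrapper has to do, assuming the constructor signature visible in the hunks above (weight_shape, activation_quant_key, weight_quant_key) and reusing the Fp8LinearOp pieces the old test code constructed by hand; the class name SketchFP8Layer and everything inside it are illustrative assumptions, not the actual helper:

    import torch

    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        Fp8LinearOp,
    )
    from vllm.platforms import current_platform

    FP8_DTYPE = current_platform.fp8_dtype()


    class SketchFP8Layer(torch.nn.Module):
        """Owns an FP8 weight plus its scales and hides Fp8LinearOp.apply()
        behind a plain forward(), mirroring the old hand-rolled test code."""

        def __init__(self, weight_shape, activation_quant_key, weight_quant_key):
            super().__init__()
            # Random FP8 weight, transposed as in the old tests.
            self.w = torch.rand(*weight_shape).to(dtype=FP8_DTYPE).t()
            self.wscale = torch.rand(1, dtype=torch.float32)
            self.input_scale = torch.rand(1, dtype=torch.float32)
            # Both tests pass kFp8StaticTensorSym, i.e. static per-tensor
            # activation quantization; a real helper would branch on the keys.
            self.fp8_linear = Fp8LinearOp(act_quant_static=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fp8_linear.apply(
                x,
                self.w,
                self.wscale,
                input_scale=self.input_scale.to(x.device),
            )

The design win is the one the diff itself shows: the quantization details live in one place, so swapping the underlying kernel abstraction touches only the helper, not every test module.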