[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -49,6 +50,7 @@ class TestSiluMul(torch.nn.Module):
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=get_current_vllm_config().model_config.dtype,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -92,6 +94,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
weight_shape=(hidden_size, intermediate_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=get_current_vllm_config().model_config.dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
|
||||
Reference in New Issue
Block a user