[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)

Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
Maral
2026-04-09 08:50:39 +08:00
committed by GitHub
parent 8332078cfd
commit 2e9034c998
35 changed files with 1710 additions and 904 deletions

View File

@@ -39,7 +39,9 @@ from vllm.utils.torch_utils import set_random_seed
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -78,7 +80,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -88,6 +92,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=dtype,
)
for i in range(3)
]
@@ -127,7 +132,9 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
model = test_model_cls(hidden_size, token_num, dtype=dtype)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

View File

@@ -109,6 +109,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=self.vllm_config.model_config.dtype,
)
for i in range(3)
]