[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com> Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
@@ -39,7 +39,9 @@ from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -78,7 +80,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -88,6 +92,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
@@ -127,7 +132,9 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
)
|
||||
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
model = test_model_cls(hidden_size, token_num, dtype=dtype)
|
||||
|
||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user