Use w8a8 quantized matmul Pallas kernel (#19170)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
This commit is contained in:
@@ -14,7 +14,7 @@ RTOL = 0.03
|
||||
@dataclass
|
||||
class GSM8KAccuracyTestConfig:
|
||||
model_name: str
|
||||
excepted_value: float
|
||||
expected_value: float
|
||||
|
||||
def get_model_args(self) -> str:
|
||||
return (f"pretrained={self.model_name},"
|
||||
@@ -25,13 +25,13 @@ class GSM8KAccuracyTestConfig:
|
||||
ACCURACY_CONFIGS = [
|
||||
GSM8KAccuracyTestConfig(
|
||||
model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
|
||||
excepted_value=0.76), # no bias
|
||||
expected_value=0.76), # no bias
|
||||
# NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
|
||||
# so only one of these tests can run in a single call to pytest. As
|
||||
# a follow up, move this into the LM-EVAL section of the CI.
|
||||
# GSM8KAccuracyTestConfig(
|
||||
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
|
||||
# excepted_value=0.66), # bias in QKV layers
|
||||
# expected_value=0.66), # bias in QKV layers
|
||||
]
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
|
||||
batch_size="auto",
|
||||
)
|
||||
|
||||
EXPECTED_VALUE = config.excepted_value
|
||||
EXPECTED_VALUE = config.expected_value
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
|
||||
Reference in New Issue
Block a user