Use w8a8 quantized matmul Pallas kernel (#19170)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Author: XiongfeiWei
Date: 2025-07-14 20:06:33 -07:00
Committed-by: GitHub
Parent: 946aadb4a0
Commit: d4170fad39
4 changed files with 50 additions and 19 deletions
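
For context: "w8a8" means both the weights and the activations are quantized to 8-bit integers, so the matmul can run on the TPU's integer units with int32 accumulation and be dequantized afterwards. The sketch below shows that underlying math in plain JAX. It is illustrative only, not the Pallas kernel this commit wires in, and every name in it is hypothetical.

# Illustrative w8a8 matmul math in plain JAX (not the vLLM Pallas kernel
# this commit enables; function and argument names are hypothetical).
import jax
import jax.numpy as jnp

def w8a8_matmul(x_q: jax.Array,      # int8 activations, shape (M, K)
                w_q: jax.Array,      # int8 weights, shape (K, N)
                x_scale: jax.Array,  # activation scale, scalar or (M, 1)
                w_scale: jax.Array,  # per-channel weight scale, shape (N,)
                ) -> jax.Array:
    # Integer matmul with int32 accumulation avoids int8 overflow.
    acc = jax.lax.dot(x_q, w_q, preferred_element_type=jnp.int32)
    # Dequantize: real_value ~= scale * quantized_value, so the product
    # of two quantized operands is rescaled by both scales.
    return acc.astype(jnp.float32) * x_scale * w_scale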


@@ -14,7 +14,7 @@ RTOL = 0.03
 @dataclass
 class GSM8KAccuracyTestConfig:
     model_name: str
-    excepted_value: float
+    expected_value: float
 
     def get_model_args(self) -> str:
         return (f"pretrained={self.model_name},"
@@ -25,13 +25,13 @@ class GSM8KAccuracyTestConfig:
 ACCURACY_CONFIGS = [
     GSM8KAccuracyTestConfig(
         model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
-        excepted_value=0.76),  # no bias
+        expected_value=0.76),  # no bias
     # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
     # so only one of these tests can run in a single call to pytest. As
     # a follow up, move this into the LM-EVAL section of the CI.
     # GSM8KAccuracyTestConfig(
     #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
-    #     excepted_value=0.66),  # bias in QKV layers
+    #     expected_value=0.66),  # bias in QKV layers
 ]
@@ -45,7 +45,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
         batch_size="auto",
     )
 
-    EXPECTED_VALUE = config.excepted_value
+    EXPECTED_VALUE = config.expected_value
     measured_value = results["results"][TASK][FILTER]
     assert (measured_value - RTOL < EXPECTED_VALUE
             and measured_value + RTOL > EXPECTED_VALUE
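
The two-sided assertion in the last hunk bounds measured_value within RTOL of EXPECTED_VALUE on both sides, so RTOL (0.03) acts as an absolute tolerance band despite its name. A compact equivalent, illustrative only and not part of this commit:

# Equivalent absolute-tolerance form of the assertion above; illustrative,
# not part of this commit. Note RTOL is applied as an absolute band.
import pytest

def assert_accuracy(measured_value: float, expected_value: float,
                    rtol: float = 0.03) -> None:
    assert measured_value == pytest.approx(expected_value, abs=rtol)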