[ROCm][CI] Fix flaky GPTQ compile correctness test (#38161)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-26 06:57:00 -05:00
committed by GitHub
parent 37a83007fe
commit f2d16207c7
2 changed files with 41 additions and 34 deletions

View File

@@ -1348,40 +1348,47 @@ def initialize_single_dummy_weight(
high: float = 1e-3,
seed: int = 1234,
) -> None:
if torch.is_floating_point(param):
if current_platform.is_tpu():
generator = torch.Generator(device="cpu")
generator.manual_seed(seed)
# Note: The param.uniform_ function cannot be used in this
# context because it demands more TPU HBM than directly copying
# from a CPU tensor.
# Note: We avoid using torch.rand_like as it doesn't currently
# support the generator argument.
param.copy_(
(high - low)
* torch.rand(
param.shape,
generator=generator,
dtype=param.dtype,
layout=param.layout,
requires_grad=param.requires_grad,
device="cpu",
)
+ low
)
torch._sync(param)
return
if not torch.is_floating_point(param):
if current_platform.is_rocm():
# On ROCm, integer params (e.g. GPTQ qweight/qzeros) are left
# as torch.empty() by default, giving non-deterministic values
# across processes. Zero them for reproducibility.
param.zero_()
return
generator = torch.Generator(device=param.data.device)
if current_platform.is_tpu():
generator = torch.Generator(device="cpu")
generator.manual_seed(seed)
if torch.finfo(param.data.dtype).bits < 16:
# uniform_ doesn't support < 16-bit datatypes (FP8)
dtype = param.data.dtype
tmp_param = param.data.to(torch.float16)
tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
param.data.copy_(tmp_param)
else:
param.uniform_(low, high, generator=generator)
# Note: The param.uniform_ function cannot be used in this
# context because it demands more TPU HBM than directly copying
# from a CPU tensor.
# Note: We avoid using torch.rand_like as it doesn't currently
# support the generator argument.
param.copy_(
(high - low)
* torch.rand(
param.shape,
generator=generator,
dtype=param.dtype,
layout=param.layout,
requires_grad=param.requires_grad,
device="cpu",
)
+ low
)
torch._sync(param)
return
generator = torch.Generator(device=param.data.device)
generator.manual_seed(seed)
if torch.finfo(param.data.dtype).bits < 16:
# uniform_ doesn't support < 16-bit datatypes (FP8)
dtype = param.data.dtype
tmp_param = param.data.to(torch.float16)
tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
param.data.copy_(tmp_param)
else:
param.uniform_(low, high, generator=generator)
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: