[ROCm][Bugfix] Disable hip sampler to fix deepseek's accuracy issue on ROCm (#32413)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
@@ -174,6 +174,8 @@ class TopKTopPSampler(nn.Module):
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# FIXME: Fix aiter_sampler's accuracy issue and remove this flag
|
||||
DISABLE_AITER_SAMPLER = True
|
||||
"""Optimized ROCm/aiter path (same structure as forward_cuda)."""
|
||||
if (k is None and p is None) or generators:
|
||||
if generators:
|
||||
@@ -186,6 +188,8 @@ class TopKTopPSampler(nn.Module):
|
||||
"processed_logits",
|
||||
"processed_logprobs",
|
||||
), "aiter sampler does not support returning logits/logprobs."
|
||||
if DISABLE_AITER_SAMPLER:
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
return self.aiter_sample(logits, k, p, generators), None
|
||||
|
||||
def aiter_sample(
|
||||
|
||||
Reference in New Issue
Block a user