[AITER] [ROCm] Fix crash when loading llama4 model with old aiter version installed; fall back to forward_native implementation (#29124)
Signed-off-by: Xiao Li <ilx@meta.com>
This commit is contained in:
@@ -60,6 +60,7 @@ class TopKTopPSampler(nn.Module):
|
|||||||
logprobs_mode not in ("processed_logits", "processed_logprobs")
|
logprobs_mode not in ("processed_logits", "processed_logprobs")
|
||||||
and rocm_aiter_ops.is_enabled()
|
and rocm_aiter_ops.is_enabled()
|
||||||
):
|
):
|
||||||
|
try:
|
||||||
import aiter.ops.sampling # noqa: F401
|
import aiter.ops.sampling # noqa: F401
|
||||||
|
|
||||||
self.aiter_ops = torch.ops.aiter
|
self.aiter_ops = torch.ops.aiter
|
||||||
@@ -67,6 +68,12 @@ class TopKTopPSampler(nn.Module):
|
|||||||
"Using aiter sampler on ROCm (lazy import, sampling-only)."
|
"Using aiter sampler on ROCm (lazy import, sampling-only)."
|
||||||
)
|
)
|
||||||
self.forward = self.forward_hip
|
self.forward = self.forward_hip
|
||||||
|
except ImportError:
|
||||||
|
logger.warning_once(
|
||||||
|
"aiter.ops.sampling is not available on ROCm. "
|
||||||
|
"Falling back to forward_native implementation."
|
||||||
|
)
|
||||||
|
self.forward = self.forward_native
|
||||||
else:
|
else:
|
||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user