[Spec-Decode] Support piecewise cudagraphs for Eagle head (#25109)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
Lucas Wilkinson
2025-10-10 01:20:31 -04:00
committed by GitHub
parent da4455609d
commit 29255cfc3b
6 changed files with 84 additions and 16 deletions

View File

@@ -50,11 +50,14 @@ class CUDAGraphMode(enum.Enum):
def mixed_mode(self) -> "CUDAGraphMode":
    """Return the mode taken from the second slot of a two-part mode.

    When ``separate_routine()`` is true, ``self.value`` is indexed at
    position 1 and wrapped back into a ``CUDAGraphMode``; otherwise the
    mode is returned unchanged.
    """
    # NOTE(review): presumably value[1] is the mixed-batch entry of a
    # (decode, mixed) pair -- confirm against the enum member definitions.
    if self.separate_routine():
        return CUDAGraphMode(self.value[1])
    return self
def has_mode(self, mode: "CUDAGraphMode") -> bool:
    """Report whether ``mode`` is part of this mode.

    Args:
        mode: The mode to look for. It must not itself be a
            separate-routine mode; this is enforced with an assert.

    Returns:
        For a separate-routine ``self``, whether ``mode.value`` occurs in
        ``self.value``; otherwise whether ``self`` equals ``mode``.
    """
    assert not mode.separate_routine()
    if not self.separate_routine():
        # Single-routine mode: only an exact match counts.
        return self == mode
    # Separate-routine mode: look for the value among the components.
    return mode.value in self.value
def requires_piecewise_compilation(self) -> bool:
    """Return whether this mode includes ``CUDAGraphMode.PIECEWISE``.

    The check is delegated to ``has_mode``, which matches the mode
    itself when it is a single-routine mode and any of its components
    when it is a separate-routine mode.
    """
    return self.has_mode(CUDAGraphMode.PIECEWISE)
def max_cudagraph_mode(self) -> "CUDAGraphMode":
    """Return the mode corresponding to the largest component value.

    When ``separate_routine()`` is true, the maximum of ``self.value`` is
    wrapped back into a ``CUDAGraphMode``; otherwise the mode is returned
    unchanged.
    """
    if self.separate_routine():
        return CUDAGraphMode(max(self.value))
    return self