[Spec-Decode] Support piecewise cudagraphs for Eagle head (#25109)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
@@ -50,11 +50,14 @@ class CUDAGraphMode(enum.Enum):
|
||||
def mixed_mode(self) -> "CUDAGraphMode":
|
||||
return CUDAGraphMode(self.value[1]) if self.separate_routine() else self
|
||||
|
||||
def has_mode(self, mode: "CUDAGraphMode") -> bool:
|
||||
assert not mode.separate_routine()
|
||||
if self.separate_routine():
|
||||
return mode.value in self.value
|
||||
return self == mode
|
||||
|
||||
def requires_piecewise_compilation(self) -> bool:
|
||||
return (
|
||||
self.decode_mode() == CUDAGraphMode.PIECEWISE
|
||||
or self.mixed_mode() == CUDAGraphMode.PIECEWISE
|
||||
)
|
||||
return self.has_mode(CUDAGraphMode.PIECEWISE)
|
||||
|
||||
def max_cudagraph_mode(self) -> "CUDAGraphMode":
|
||||
return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
|
||||
|
||||
Reference in New Issue
Block a user