[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce device-to-host (d2h) transfers) (#26437)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2025-10-16 08:18:31 +04:00
committed by GitHub
parent f6cdc9a02f
commit 785d8b6410
3 changed files with 56 additions and 36 deletions

View File

@@ -45,7 +45,7 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
"""
cache_entries: tuple[tuple | None, dict | None, Any] = []
-cache_size = 4
+cache_size = 8
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any: