[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Co-authored-by: SumanthRH <sumanthrh99@gmail.com>
This commit is contained in:
Aaron Hao
2026-02-05 09:13:23 -08:00
committed by GitHub
parent 82914d2ae8
commit c1858b7ec8
27 changed files with 2974 additions and 2 deletions

View File

@@ -649,6 +649,9 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
)
# Activations not quantized for marlin.
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
@@ -908,6 +911,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -1241,6 +1247,9 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer.w2_input_scale,
)
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""