[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)
Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Aaron Hao <ahao@anyscale.com> Co-authored-by: SumanthRH <sumanthrh99@gmail.com>
This commit is contained in:
@@ -649,6 +649,9 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
)
|
||||
# Activations not quantized for marlin.
|
||||
|
||||
# Prevent duplicate processing (e.g., during weight reload)
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for FP8.
|
||||
@@ -908,6 +911,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
|
||||
)
|
||||
|
||||
# Prevent duplicate processing (e.g., during weight reload)
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
@@ -1241,6 +1247,9 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
|
||||
# Prevent duplicate processing (e.g., during weight reload)
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user