[EPLB] Simplify EPLB rearrange by only returning one map (#36267)
Signed-off-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.distributed.eplb.eplb_state import compute_logical_maps
|
||||||
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
|
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
|
||||||
|
|
||||||
|
|
||||||
@@ -24,9 +25,10 @@ def test_basic_rebalance():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 8
|
num_gpus = 8
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
|
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
|
|
||||||
# Verify output shapes
|
# Verify output shapes
|
||||||
assert phy2log.shape == (
|
assert phy2log.shape == (
|
||||||
@@ -78,9 +80,10 @@ def test_single_gpu_case():
|
|||||||
num_nodes = 1
|
num_nodes = 1
|
||||||
num_gpus = 1
|
num_gpus = 1
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
|
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 4)
|
assert phy2log.shape == (1, 4)
|
||||||
@@ -100,9 +103,10 @@ def test_equal_weights():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
|
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 8)
|
assert phy2log.shape == (1, 8)
|
||||||
@@ -123,9 +127,10 @@ def test_extreme_weight_imbalance():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
|
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (1, 12)
|
assert phy2log.shape == (1, 12)
|
||||||
@@ -151,9 +156,10 @@ def test_multiple_layers():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
|
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
|
|
||||||
# Verify shapes
|
# Verify shapes
|
||||||
assert phy2log.shape == (3, 8)
|
assert phy2log.shape == (3, 8)
|
||||||
@@ -176,7 +182,8 @@ def test_parameter_validation():
|
|||||||
# Test non-divisible case - this should handle normally without throwing
|
# Test non-divisible case - this should handle normally without throwing
|
||||||
# errors because the function will fall back to global load balancing
|
# errors because the function will fall back to global load balancing
|
||||||
# strategy
|
# strategy
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
|
phy2log = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
|
||||||
|
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
assert phy2log.shape == (1, 8)
|
assert phy2log.shape == (1, 8)
|
||||||
assert logcnt.shape == (1, 4)
|
assert logcnt.shape == (1, 4)
|
||||||
|
|
||||||
@@ -198,9 +205,10 @@ def test_small_scale_hierarchical():
|
|||||||
num_nodes = 2 # 2 nodes
|
num_nodes = 2 # 2 nodes
|
||||||
num_gpus = 4 # 4 GPUs
|
num_gpus = 4 # 4 GPUs
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
|
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
|
|
||||||
# Verify basic constraints
|
# Verify basic constraints
|
||||||
assert phy2log.shape == (1, 12)
|
assert phy2log.shape == (1, 12)
|
||||||
@@ -225,9 +233,10 @@ def test_global_load_balance_fallback():
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 4
|
num_gpus = 4
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
|
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
|
|
||||||
# Should work normally, just using global load balancing strategy
|
# Should work normally, just using global load balancing strategy
|
||||||
assert phy2log.shape == (1, 8)
|
assert phy2log.shape == (1, 8)
|
||||||
@@ -247,9 +256,10 @@ def test_device_compatibility(device):
|
|||||||
num_nodes = 1
|
num_nodes = 1
|
||||||
num_gpus = 2
|
num_gpus = 2
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
|
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||||
|
|
||||||
# Function will convert to CPU internally, but should handle different
|
# Function will convert to CPU internally, but should handle different
|
||||||
# device inputs normally
|
# device inputs normally
|
||||||
@@ -264,9 +274,8 @@ def test_additional_cases():
|
|||||||
weight1 = torch.tensor(
|
weight1 = torch.tensor(
|
||||||
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
|
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
|
||||||
)
|
)
|
||||||
phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
|
phy2log1 = DefaultEplbPolicy.rebalance_experts(weight1, 24, 8, 4, 8)
|
||||||
weight1, 24, 8, 4, 8
|
_, logcnt1 = compute_logical_maps(phy2log1, weight1.shape[-1])
|
||||||
)
|
|
||||||
|
|
||||||
assert phy2log1.shape == (1, 24)
|
assert phy2log1.shape == (1, 24)
|
||||||
assert logcnt1.shape == (1, 16)
|
assert logcnt1.shape == (1, 16)
|
||||||
@@ -279,9 +288,8 @@ def test_additional_cases():
|
|||||||
[12, 25, 50, 100, 150, 200], # Increasing weights
|
[12, 25, 50, 100, 150, 200], # Increasing weights
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
|
phy2log2 = DefaultEplbPolicy.rebalance_experts(weight2, 10, 3, 1, 2)
|
||||||
weight2, 10, 3, 1, 2
|
_, logcnt2 = compute_logical_maps(phy2log2, weight2.shape[-1])
|
||||||
)
|
|
||||||
|
|
||||||
assert phy2log2.shape == (2, 10)
|
assert phy2log2.shape == (2, 10)
|
||||||
assert logcnt2.shape == (2, 6)
|
assert logcnt2.shape == (2, 6)
|
||||||
@@ -292,6 +300,42 @@ def test_additional_cases():
|
|||||||
assert logcnt2[layer, max_weight_idx] >= 2
|
assert logcnt2[layer, max_weight_idx] >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_logical_maps_with_negative_indices():
    """
    Test that compute_logical_maps correctly handles physical slots containing
    -1 (unused slots).

    Both layers are checked in full: layer 0 has its unused slots after valid
    ones, while layer 1 uses a different permutation with an unused slot in the
    interior, so a bug that only shows up for one layout is still caught.
    """
    # 2 layers, 6 physical slots, 4 logical experts.
    # Each layer has two unused (-1) slots; every logical expert appears once.
    phy2log = torch.tensor(
        [
            [0, 1, -1, 2, 3, -1],
            [3, -1, 2, 1, 0, -1],
        ]
    )
    num_layers = 2
    num_logical_experts = 4

    log2phy, logcnt = compute_logical_maps(phy2log, num_logical_experts)

    assert logcnt.shape == (num_layers, num_logical_experts)
    assert log2phy.shape == (num_layers, num_logical_experts, 1)

    expected_logcnt = torch.ones(num_layers, num_logical_experts, dtype=phy2log.dtype)
    assert torch.all(logcnt == expected_logcnt), (
        f"Expected that all replica counts == 1, got {logcnt}"
    )

    assert torch.all(log2phy >= 0), (
        "log2phy should only contain valid physical indices, not -1"
    )

    # Expected physical slot of each logical expert, for every layer.
    expected_log2phy = torch.tensor(
        [
            [[0], [1], [3], [4]],
            [[4], [3], [2], [0]],
        ],
        dtype=phy2log.dtype,
    )
    assert torch.equal(log2phy, expected_log2phy), (
        f"Expected log2phy == {expected_log2phy}, got {log2phy}"
    )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
weight = torch.tensor(
|
weight = torch.tensor(
|
||||||
[
|
[
|
||||||
@@ -305,7 +349,7 @@ if __name__ == "__main__":
|
|||||||
num_nodes = 2
|
num_nodes = 2
|
||||||
num_gpus = 8
|
num_gpus = 8
|
||||||
|
|
||||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||||
)
|
)
|
||||||
print(phy2log)
|
print(phy2log)
|
||||||
@@ -434,9 +478,10 @@ def test_preserve_intragpu_slots(
|
|||||||
"""Experts that stay on a GPU keep their old slots; incoming not lost."""
|
"""Experts that stay on a GPU keep their old slots; incoming not lost."""
|
||||||
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
|
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
|
||||||
|
|
||||||
post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
|
post_phy2log = DefaultEplbPolicy.preserve_intragpu_slots(
|
||||||
new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
|
new_phy2log, num_ranks, old_phy2log
|
||||||
)
|
)
|
||||||
|
post_phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(post_phy2log)
|
||||||
|
|
||||||
# Shapes preserved
|
# Shapes preserved
|
||||||
assert post_phy2log.shape == new_phy2log.shape
|
assert post_phy2log.shape == new_phy2log.shape
|
||||||
|
|||||||
@@ -73,11 +73,7 @@ def run_rebalance_experts(
|
|||||||
# Move the global expert load window to CPU for computation.
|
# Move the global expert load window to CPU for computation.
|
||||||
global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
|
global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
|
||||||
# Compute new expert mappings for the model
|
# Compute new expert mappings for the model
|
||||||
(
|
new_physical_to_logical_map = eplb_state.policy.rebalance_experts(
|
||||||
new_physical_to_logical_map,
|
|
||||||
new_logical_to_physical_map,
|
|
||||||
new_logical_replica_count,
|
|
||||||
) = eplb_state.policy.rebalance_experts(
|
|
||||||
global_expert_load_window,
|
global_expert_load_window,
|
||||||
eplb_stats.num_replicas,
|
eplb_stats.num_replicas,
|
||||||
eplb_stats.num_groups,
|
eplb_stats.num_groups,
|
||||||
@@ -89,16 +85,6 @@ def run_rebalance_experts(
|
|||||||
|
|
||||||
model_state.new_physical_to_logical_map = new_physical_to_logical_map
|
model_state.new_physical_to_logical_map = new_physical_to_logical_map
|
||||||
|
|
||||||
max_slots = model_state.logical_to_physical_map.shape[-1]
|
|
||||||
padded_logical = torch.nn.functional.pad(
|
|
||||||
new_logical_to_physical_map,
|
|
||||||
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
|
|
||||||
value=-1,
|
|
||||||
).to(model_state.logical_to_physical_map.device)
|
|
||||||
new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
|
|
||||||
model_state.new_logical_to_physical_map = padded_logical
|
|
||||||
model_state.new_logical_replica_count = new_replica
|
|
||||||
|
|
||||||
|
|
||||||
async def transfer_run_periodically(
|
async def transfer_run_periodically(
|
||||||
state: "EplbState",
|
state: "EplbState",
|
||||||
|
|||||||
@@ -235,16 +235,6 @@ class EplbModelState:
|
|||||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||||
the size is same as physical_to_logical_map
|
the size is same as physical_to_logical_map
|
||||||
"""
|
"""
|
||||||
new_logical_to_physical_map: torch.Tensor | None = None
|
|
||||||
"""
|
|
||||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
|
||||||
the size is same as logical_to_physical_map
|
|
||||||
"""
|
|
||||||
new_logical_replica_count: torch.Tensor | None = None
|
|
||||||
"""
|
|
||||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
|
||||||
the size is same as logical_replica_count
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class EplbState:
|
class EplbState:
|
||||||
@@ -508,8 +498,6 @@ class EplbState:
|
|||||||
),
|
),
|
||||||
cuda_device_index=self.cuda_device_index,
|
cuda_device_index=self.cuda_device_index,
|
||||||
new_physical_to_logical_map=None,
|
new_physical_to_logical_map=None,
|
||||||
new_logical_to_physical_map=None,
|
|
||||||
new_logical_replica_count=None,
|
|
||||||
)
|
)
|
||||||
self.model_states[model_config.compute_hash()] = model_state
|
self.model_states[model_config.compute_hash()] = model_state
|
||||||
self.num_valid_physical_experts = model.num_physical_experts
|
self.num_valid_physical_experts = model.num_physical_experts
|
||||||
@@ -738,17 +726,20 @@ class EplbState:
|
|||||||
):
|
):
|
||||||
if not self.is_async or is_profile:
|
if not self.is_async or is_profile:
|
||||||
# Get new expert mappings for the model
|
# Get new expert mappings for the model
|
||||||
(
|
new_physical_to_logical_map = self.policy.rebalance_experts(
|
||||||
new_physical_to_logical_map,
|
global_expert_load_window.cpu(),
|
||||||
new_logical_to_physical_map,
|
|
||||||
new_logical_replica_count,
|
|
||||||
) = self.policy.rebalance_experts(
|
|
||||||
global_expert_load_window,
|
|
||||||
num_replicas,
|
num_replicas,
|
||||||
num_groups,
|
num_groups,
|
||||||
num_nodes,
|
num_nodes,
|
||||||
num_gpus,
|
num_gpus,
|
||||||
eplb_model_state.physical_to_logical_map,
|
eplb_model_state.physical_to_logical_map.cpu(),
|
||||||
|
)
|
||||||
|
|
||||||
|
num_logical_experts = global_expert_load_window.shape[-1]
|
||||||
|
(new_logical_to_physical_map, new_logical_replica_count) = (
|
||||||
|
compute_logical_maps(
|
||||||
|
new_physical_to_logical_map, num_logical_experts
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update expert weights
|
# Update expert weights
|
||||||
@@ -847,11 +838,7 @@ class EplbState:
|
|||||||
def _update_layer_mapping_from_new(
|
def _update_layer_mapping_from_new(
|
||||||
self, model_state: EplbModelState, layer: int
|
self, model_state: EplbModelState, layer: int
|
||||||
) -> None:
|
) -> None:
|
||||||
if (
|
if model_state.new_physical_to_logical_map is None:
|
||||||
model_state.new_physical_to_logical_map is None
|
|
||||||
or model_state.new_logical_to_physical_map is None
|
|
||||||
or model_state.new_logical_replica_count is None
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
target_device = model_state.physical_to_logical_map.device
|
target_device = model_state.physical_to_logical_map.device
|
||||||
@@ -865,19 +852,23 @@ class EplbState:
|
|||||||
new_physical[layer].to(target_device, non_blocking=True)
|
new_physical[layer].to(target_device, non_blocking=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_logical_experts = model_state.logical_to_physical_map.shape[1]
|
||||||
|
new_logical, new_replica_count = compute_logical_maps(
|
||||||
|
new_physical[layer], num_logical_experts
|
||||||
|
)
|
||||||
|
|
||||||
logical_device = model_state.logical_to_physical_map.device
|
logical_device = model_state.logical_to_physical_map.device
|
||||||
new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
|
|
||||||
max_slots = model_state.logical_to_physical_map.shape[-1]
|
max_slots = model_state.logical_to_physical_map.shape[-1]
|
||||||
slot_delta = max_slots - new_logical.shape[-1]
|
slot_delta = max_slots - new_logical.shape[-1]
|
||||||
if slot_delta > 0:
|
if slot_delta > 0:
|
||||||
new_logical = torch.nn.functional.pad(
|
new_logical = torch.nn.functional.pad(
|
||||||
new_logical, (0, slot_delta), value=-1
|
new_logical, (0, slot_delta), value=-1
|
||||||
)
|
)
|
||||||
model_state.logical_to_physical_map[layer].copy_(new_logical)
|
model_state.logical_to_physical_map[layer].copy_(new_logical.to(logical_device))
|
||||||
|
|
||||||
replica_device = model_state.logical_replica_count.device
|
replica_device = model_state.logical_replica_count.device
|
||||||
model_state.logical_replica_count[layer].copy_(
|
model_state.logical_replica_count[layer].copy_(
|
||||||
model_state.new_logical_replica_count[layer].to(replica_device)
|
new_replica_count.to(replica_device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
|
def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
|
||||||
@@ -966,7 +957,7 @@ class EplbState:
|
|||||||
transferred_layer,
|
transferred_layer,
|
||||||
)
|
)
|
||||||
if model_state.layer_to_transfer >= model_state.model.num_moe_layers:
|
if model_state.layer_to_transfer >= model_state.model.num_moe_layers:
|
||||||
self.post_eplb(model_state, is_profile)
|
self.post_eplb(model_state)
|
||||||
model_state.rebalanced = False
|
model_state.rebalanced = False
|
||||||
model_state.layer_to_transfer = 0
|
model_state.layer_to_transfer = 0
|
||||||
model_state.pending_global_ready_check = False
|
model_state.pending_global_ready_check = False
|
||||||
@@ -987,14 +978,9 @@ class EplbState:
|
|||||||
str(e),
|
str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None:
|
def post_eplb(self, model_state: EplbModelState) -> None:
|
||||||
assert model_state.new_physical_to_logical_map is not None
|
assert model_state.new_physical_to_logical_map is not None
|
||||||
assert model_state.new_logical_to_physical_map is not None
|
|
||||||
assert model_state.new_logical_replica_count is not None
|
|
||||||
|
|
||||||
model_state.new_physical_to_logical_map = None
|
model_state.new_physical_to_logical_map = None
|
||||||
model_state.new_logical_to_physical_map = None
|
|
||||||
model_state.new_logical_replica_count = None
|
|
||||||
|
|
||||||
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
|
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -1052,39 +1038,28 @@ class EplbState:
|
|||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
)
|
)
|
||||||
eplb_state.num_valid_physical_experts = num_valid_physical_experts
|
eplb_state.num_valid_physical_experts = num_valid_physical_experts
|
||||||
num_moe_layers = expanded_physical_to_logical.shape[0]
|
|
||||||
num_physical_experts = expanded_physical_to_logical.shape[1]
|
|
||||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||||
eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
|
eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
|
||||||
|
|
||||||
logical_to_physical_map = torch.full(
|
(logical_to_physical_map_cpu, logical_replica_count_cpu) = compute_logical_maps(
|
||||||
(
|
expanded_physical_to_logical.cpu(), model.num_logical_experts
|
||||||
num_moe_layers,
|
|
||||||
model.num_logical_experts,
|
|
||||||
eplb_model_state.logical_to_physical_map.shape[2],
|
|
||||||
),
|
|
||||||
-1,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
)
|
||||||
logical_replica_count = torch.zeros(
|
|
||||||
(num_moe_layers, model.num_logical_experts),
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy()
|
|
||||||
for layer_idx in range(num_moe_layers):
|
|
||||||
for phys_idx in range(num_physical_experts):
|
|
||||||
logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx]
|
|
||||||
if logical_idx >= 0:
|
|
||||||
replica_idx = logical_replica_count[layer_idx, logical_idx]
|
|
||||||
logical_to_physical_map[layer_idx, logical_idx, replica_idx] = (
|
|
||||||
phys_idx
|
|
||||||
)
|
|
||||||
logical_replica_count[layer_idx, logical_idx] += 1
|
|
||||||
|
|
||||||
logical_to_physical_map = logical_to_physical_map.to(device)
|
max_num_replicas = eplb_model_state.logical_to_physical_map.shape[-1]
|
||||||
logical_replica_count = logical_replica_count.to(device)
|
num_replicas = logical_to_physical_map_cpu.shape[-1]
|
||||||
|
logical_to_physical_map = torch.nn.functional.pad(
|
||||||
|
logical_to_physical_map_cpu,
|
||||||
|
(
|
||||||
|
0,
|
||||||
|
max_num_replicas - num_replicas,
|
||||||
|
),
|
||||||
|
value=-1,
|
||||||
|
).to(device)
|
||||||
|
logical_replica_count = logical_replica_count_cpu.to(device)
|
||||||
|
|
||||||
eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
|
eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
|
||||||
eplb_model_state.logical_replica_count.copy_(logical_replica_count)
|
eplb_model_state.logical_replica_count.copy_(logical_replica_count)
|
||||||
|
|
||||||
return eplb_state
|
return eplb_state
|
||||||
|
|
||||||
|
|
||||||
@@ -1132,3 +1107,82 @@ def _node_count_with_rank_mapping(
|
|||||||
node_assignment[other_rank] = next_node_id
|
node_assignment[other_rank] = next_node_id
|
||||||
|
|
||||||
return next_node_id
|
return next_node_id
|
||||||
|
|
||||||
|
|
||||||
|
def compute_logical_maps(
    physical_to_logical_map: torch.Tensor,
    num_logical_experts: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Derive logical_to_physical_map and logical_replica_count from
    physical_to_logical_map.

    Negative entries in physical_to_logical_map mark unused physical slots and
    are ignored. Replicas of a logical expert are listed in ascending physical
    slot order. The input must live on the CPU; the outputs are created on the
    same device and with the same dtype as the input.

    Args:
        physical_to_logical_map: [num_layers, num_physical_experts], logical
            expert index for each physical expert slot. A 1-D tensor
            [num_physical_experts] is treated as a single layer, and the
            returned tensors then have no layer dimension.
        num_logical_experts: total number of logical experts

    Returns:
        logical_to_physical_map: [num_layers, num_logical_experts, max_replicas],
            physical slots per logical expert; -1 where unused. max_replicas is
            the largest replica count over all layers and experts (0 if there
            is no valid slot at all).
        logical_replica_count: [num_layers, num_logical_experts], number of
            physical replicas per logical expert
    """
    device = physical_to_logical_map.device
    assert device.type == "cpu"

    dtype = physical_to_logical_map.dtype

    # If computing maps for a single layer, add a single-element layer dimension.
    per_layer = physical_to_logical_map.dim() == 1
    phy2log = (
        physical_to_logical_map.unsqueeze(0) if per_layer else physical_to_logical_map
    )
    assert phy2log.dim() == 2
    num_layers, num_physical = phy2log.shape

    valid_mask = phy2log >= 0
    # scatter_add_ only accepts int64 indices, so build the index tensor
    # explicitly instead of relying on the input already being int64. Unused
    # slots are clamped to expert 0 and contribute a weight of 0 below.
    expert_index = phy2log.clamp(min=0).to(torch.int64)

    logical_replica_count = torch.zeros(
        num_layers,
        num_logical_experts,
        dtype=dtype,
        device=device,
    )
    # Skipping the scatter when nothing is valid is equivalent (only zeros would
    # be added) and avoids indexing into a zero-width expert dimension.
    if bool(valid_mask.any()):
        logical_replica_count.scatter_add_(1, expert_index, valid_mask.to(dtype))

    # max() raises on a tensor without elements (no layers or no logical
    # experts); such an input simply has no replicas.
    if logical_replica_count.numel() > 0:
        max_replicas = int(logical_replica_count.max().item())
    else:
        max_replicas = 0
    logical_to_physical_map_out = torch.full(
        (num_layers, num_logical_experts, max_replicas),
        -1,
        dtype=dtype,
        device=device,
    )

    # Per (layer, expert) number of replicas placed so far; kept in int64
    # because it is used directly as an index.
    running_count = torch.zeros(
        num_layers, num_logical_experts, dtype=torch.int64, device=device
    )
    layer_indices = torch.arange(num_layers, device=device)
    for phys_idx in range(num_physical):
        # Scale up will set the logical expert ids to -1 for all new physical
        # experts. Only consider "valid" experts when setting up the
        # logical_to_physical map.
        slot_valid = valid_mask[:, phys_idx]  # [num_layers]
        if not slot_valid.any():
            continue
        valid_layers = layer_indices[slot_valid]
        # Logical expert at physical slot phys_idx for each valid layer.
        valid_experts = expert_index[:, phys_idx][slot_valid]

        # Use the current running count as the replica index, then increment it.
        # Each layer occurs at most once here, so the indexed update has no
        # duplicate targets.
        replica_idx = running_count[valid_layers, valid_experts]
        logical_to_physical_map_out[valid_layers, valid_experts, replica_idx] = phys_idx
        running_count[valid_layers, valid_experts] += 1

    # If computing maps for a single layer, squeeze out the extra layer dimension
    # before returning.
    if per_layer:
        return logical_to_physical_map_out.squeeze(0), logical_replica_count.squeeze(0)
    return logical_to_physical_map_out, logical_replica_count
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class AbstractEplbPolicy(ABC):
|
|||||||
num_nodes: int,
|
num_nodes: int,
|
||||||
num_ranks: int,
|
num_ranks: int,
|
||||||
old_global_expert_indices: torch.Tensor | None = None,
|
old_global_expert_indices: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Entry point for expert-parallelism load balancer.
|
Entry point for expert-parallelism load balancer.
|
||||||
|
|
||||||
@@ -35,9 +35,5 @@ class AbstractEplbPolicy(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
physical_to_logical_map: [layers, num_replicas], the expert
|
physical_to_logical_map: [layers, num_replicas], the expert
|
||||||
index of each replica
|
index of each replica
|
||||||
logical_to_physical_map: [layers, num_logical_experts, X],
|
|
||||||
the replica indices for each expert
|
|
||||||
expert_count: [layers, num_logical_experts], number of
|
|
||||||
physical replicas for each logical expert
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def replicate_experts(
|
def replicate_experts(
|
||||||
cls, weight: np.ndarray, num_phy: int
|
cls, weight: np.ndarray, num_phy: int
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
||||||
load of all replicas is minimized.
|
load of all replicas is minimized.
|
||||||
@@ -86,22 +86,19 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||||
replica_idx: [X, num_phy], the index of the replica for each logical expert
|
|
||||||
logcnt: [X, num_log], number of replicas for each logical expert
|
logcnt: [X, num_log], number of replicas for each logical expert
|
||||||
"""
|
"""
|
||||||
n, num_log = weight.shape
|
n, num_log = weight.shape
|
||||||
num_redundant = num_phy - num_log
|
num_redundant = num_phy - num_log
|
||||||
assert num_redundant >= 0
|
assert num_redundant >= 0
|
||||||
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
|
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
|
||||||
replica_idx = np.zeros((n, num_phy), dtype=np.int64)
|
|
||||||
logcnt = np.ones((n, num_log), dtype=np.int64)
|
logcnt = np.ones((n, num_log), dtype=np.int64)
|
||||||
arangen = np.arange(n, dtype=np.int64)
|
arangen = np.arange(n, dtype=np.int64)
|
||||||
for i in range(num_log, num_phy):
|
for i in range(num_log, num_phy):
|
||||||
redundant_indices = np.argmax(weight / logcnt, axis=-1)
|
redundant_indices = np.argmax(weight / logcnt, axis=-1)
|
||||||
phy2log[:, i] = redundant_indices
|
phy2log[:, i] = redundant_indices
|
||||||
replica_idx[:, i] = logcnt[arangen, redundant_indices]
|
|
||||||
logcnt[arangen, redundant_indices] += 1
|
logcnt[arangen, redundant_indices] += 1
|
||||||
return phy2log, replica_idx, logcnt
|
return phy2log, logcnt
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def rebalance_experts_hierarchical(
|
def rebalance_experts_hierarchical(
|
||||||
@@ -111,7 +108,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
num_groups: int,
|
num_groups: int,
|
||||||
num_nodes: int,
|
num_nodes: int,
|
||||||
num_gpus: int,
|
num_gpus: int,
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
weight: [num_moe_layers, num_logical_experts]
|
weight: [num_moe_layers, num_logical_experts]
|
||||||
@@ -124,10 +121,6 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
phy2log: [layers, num_replicas], the expert
|
phy2log: [layers, num_replicas], the expert
|
||||||
index of each replica
|
index of each replica
|
||||||
pphy_replicas_idx: [layers, num_logical_experts, X],
|
|
||||||
the replica indices for each expert
|
|
||||||
logcnt: [layers, num_logical_experts], number of
|
|
||||||
physical replicas for each logical expert
|
|
||||||
"""
|
"""
|
||||||
num_layers, num_logical_experts = weight.shape
|
num_layers, num_logical_experts = weight.shape
|
||||||
assert num_logical_experts % num_groups == 0
|
assert num_logical_experts % num_groups == 0
|
||||||
@@ -167,7 +160,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
|
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
|
||||||
-1, num_logical_experts // num_nodes
|
-1, num_logical_experts // num_nodes
|
||||||
)
|
)
|
||||||
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
|
phy2mlog, mlogcnt = cls.replicate_experts(
|
||||||
tokens_per_mlog, num_physical_experts // num_nodes
|
tokens_per_mlog, num_physical_experts // num_nodes
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -193,22 +186,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
).reshape(num_layers, -1)
|
).reshape(num_layers, -1)
|
||||||
# Map node-local logical indices back to global logical expert ids.
|
# Map node-local logical indices back to global logical expert ids.
|
||||||
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
|
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
|
||||||
# Reorder replica ranks to the post-packing physical ordering.
|
return pphy2log
|
||||||
pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
|
|
||||||
num_layers, -1
|
|
||||||
)
|
|
||||||
# Convert replica counts back to the original logical ordering.
|
|
||||||
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
|
|
||||||
return pphy2log, pphy_replicas_idx, logcnt
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def preserve_intragpu_slots(
|
def preserve_intragpu_slots(
|
||||||
cls,
|
cls,
|
||||||
phy2log: np.ndarray,
|
phy2log: np.ndarray,
|
||||||
phy_replicas_idx: np.ndarray,
|
|
||||||
num_ranks: int,
|
num_ranks: int,
|
||||||
old_phy2log: np.ndarray,
|
old_phy2log: np.ndarray,
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Reorder the new mapping per GPU so that experts that remain on the same GPU
|
Reorder the new mapping per GPU so that experts that remain on the same GPU
|
||||||
keep their previous slot positions when possible. Incoming experts to that GPU
|
keep their previous slot positions when possible. Incoming experts to that GPU
|
||||||
@@ -218,14 +204,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
"""
|
"""
|
||||||
num_phy_experts = phy2log.shape[1]
|
num_phy_experts = phy2log.shape[1]
|
||||||
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
|
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
|
||||||
return phy2log, phy_replicas_idx
|
return phy2log
|
||||||
|
|
||||||
# Move to CPU and convert to NumPy for processing
|
# Move to CPU and convert to NumPy for processing
|
||||||
slots_per_gpu = num_phy_experts // num_ranks
|
slots_per_gpu = num_phy_experts // num_ranks
|
||||||
num_layers = phy2log.shape[0]
|
num_layers = phy2log.shape[0]
|
||||||
|
|
||||||
post_phy2log = phy2log.copy()
|
post_phy2log = phy2log.copy()
|
||||||
post_phy_replicas_idx = phy_replicas_idx.copy()
|
|
||||||
|
|
||||||
for gpu_idx in range(num_ranks):
|
for gpu_idx in range(num_ranks):
|
||||||
start = gpu_idx * slots_per_gpu
|
start = gpu_idx * slots_per_gpu
|
||||||
@@ -233,7 +218,6 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
# Experts across all layers for this GPU
|
# Experts across all layers for this GPU
|
||||||
old_local = old_phy2log[:, start:end] # [layers, slots]
|
old_local = old_phy2log[:, start:end] # [layers, slots]
|
||||||
new_local = phy2log[:, start:end] # [layers, slots]
|
new_local = phy2log[:, start:end] # [layers, slots]
|
||||||
new_ridx = phy_replicas_idx[:, start:end] # [layers, slots]
|
|
||||||
|
|
||||||
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
||||||
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
||||||
@@ -253,9 +237,6 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
post_phy2log[layer_indices, start + slot_idx] = new_local[
|
post_phy2log[layer_indices, start + slot_idx] = new_local[
|
||||||
layer_indices, matched_new_positions
|
layer_indices, matched_new_positions
|
||||||
]
|
]
|
||||||
post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
|
|
||||||
layer_indices, matched_new_positions
|
|
||||||
]
|
|
||||||
used_new_indices[layer_indices, matched_new_positions] = True
|
used_new_indices[layer_indices, matched_new_positions] = True
|
||||||
preserved_positions[layer_indices, slot_idx] = True
|
preserved_positions[layer_indices, slot_idx] = True
|
||||||
|
|
||||||
@@ -287,11 +268,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
post_phy2log[layer_idx, start + dst_pos] = new_local[
|
post_phy2log[layer_idx, start + dst_pos] = new_local[
|
||||||
layer_idx, src_pos
|
layer_idx, src_pos
|
||||||
]
|
]
|
||||||
post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
|
|
||||||
layer_idx, src_pos
|
|
||||||
]
|
|
||||||
|
|
||||||
return post_phy2log, post_phy_replicas_idx
|
return post_phy2log
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def rebalance_experts(
|
def rebalance_experts(
|
||||||
@@ -302,7 +280,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
num_nodes: int,
|
num_nodes: int,
|
||||||
num_ranks: int,
|
num_ranks: int,
|
||||||
old_global_expert_indices: torch.Tensor | None = None,
|
old_global_expert_indices: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Entry point for expert-parallelism load balancer.
|
Entry point for expert-parallelism load balancer.
|
||||||
|
|
||||||
@@ -321,13 +299,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
phy2log: [layers, num_replicas], the expert
|
phy2log: [layers, num_replicas], the expert
|
||||||
index of each replica
|
index of each replica
|
||||||
log2phy: [layers, num_logical_experts, X],
|
|
||||||
the replica indices for each expert
|
|
||||||
logcnt: [layers, num_logical_experts], number of
|
|
||||||
physical replicas for each logical expert
|
|
||||||
"""
|
"""
|
||||||
device = weight.device
|
|
||||||
num_layers, num_logical_experts = weight.shape
|
|
||||||
weight_np = weight.float().cpu().numpy()
|
weight_np = weight.float().cpu().numpy()
|
||||||
old_phy2log_np = (
|
old_phy2log_np = (
|
||||||
old_global_expert_indices.cpu().numpy()
|
old_global_expert_indices.cpu().numpy()
|
||||||
@@ -337,17 +309,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
|
|
||||||
if num_groups % num_nodes == 0:
|
if num_groups % num_nodes == 0:
|
||||||
# use hierarchical load-balance policy
|
# use hierarchical load-balance policy
|
||||||
phy2log_np, phy_replicas_idx_np, logcnt_np = (
|
phy2log_np = cls.rebalance_experts_hierarchical(
|
||||||
cls.rebalance_experts_hierarchical(
|
weight_np, num_replicas, num_groups, num_nodes, num_ranks
|
||||||
weight_np, num_replicas, num_groups, num_nodes, num_ranks
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# use global load-balance policy
|
# use global load-balance policy
|
||||||
phy2log_np, phy_replicas_idx_np, logcnt_np = (
|
phy2log_np = cls.rebalance_experts_hierarchical(
|
||||||
cls.rebalance_experts_hierarchical(
|
weight_np, num_replicas, 1, 1, num_ranks
|
||||||
weight_np, num_replicas, 1, 1, num_ranks
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional postprocessing to preserve slots for experts moving
|
# Optional postprocessing to preserve slots for experts moving
|
||||||
@@ -355,22 +323,10 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
# Only apply when the number of GPUs and slots per GPU remain unchanged.
|
# Only apply when the number of GPUs and slots per GPU remain unchanged.
|
||||||
# Helps to avoid unnecessary weight copying when experts move
|
# Helps to avoid unnecessary weight copying when experts move
|
||||||
# within the same GPU.
|
# within the same GPU.
|
||||||
if old_global_expert_indices is not None:
|
if old_phy2log_np is not None:
|
||||||
phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
|
phy2log_np = cls.preserve_intragpu_slots(
|
||||||
phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
|
phy2log_np, num_ranks, old_phy2log_np
|
||||||
)
|
)
|
||||||
num_redundant_experts = num_replicas - num_logical_experts
|
|
||||||
maxlogcnt = num_redundant_experts + 1
|
|
||||||
log2phy_np = np.full(
|
|
||||||
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
|
|
||||||
)
|
|
||||||
layer_indices = np.arange(num_layers)[:, None]
|
|
||||||
replica_indices = np.tile(
|
|
||||||
np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
|
|
||||||
)
|
|
||||||
log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
|
|
||||||
|
|
||||||
phy2log = torch.from_numpy(phy2log_np).to(device)
|
phy2log = torch.from_numpy(phy2log_np)
|
||||||
log2phy = torch.from_numpy(log2phy_np).to(device)
|
return phy2log
|
||||||
logcnt = torch.from_numpy(logcnt_np).to(device)
|
|
||||||
return phy2log, log2phy, logcnt
|
|
||||||
|
|||||||
Reference in New Issue
Block a user