add router mode debug print
This commit is contained in:
@@ -23,6 +23,9 @@ CHECKPOINT_DIR = os.environ.get(
|
||||
# GPU count, read from the NUM_GPUS environment variable (default 8).
# main() places layer index li on device cuda:{li % NUM_GPUS}.
NUM_GPUS = int(os.environ.get("NUM_GPUS", "8"))
|
||||
# Device string naming CUDA device 0. NOTE(review): its consumers are not
# visible in this excerpt — confirm where it is used.
DEVICE = "cuda:0"
|
||||
# Read from the TEST_LAYERS environment variable (default 5). Presumably the
# number of layers exercised by the test — TODO confirm against its use in main().
TEST_LAYERS = int(os.environ.get("TEST_LAYERS", "5"))
|
||||
# First layer index to test. Layers 0-1 use hash routing; layers 2+ are dense/CSA/HCA.
|
||||
# Set to 0 to include hash layers.
|
||||
FIRST_LAYER = int(os.environ.get("FIRST_LAYER", "2"))
|
||||
|
||||
|
||||
def cosine(a, b):
|
||||
@@ -280,6 +283,9 @@ def main():
|
||||
gpu = li % NUM_GPUS
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
|
||||
torch.cuda.set_device(gpu)
|
||||
if pi == 0:
|
||||
r = routers.get(li)
|
||||
print(f" L{li} router: mode={r.mode if r else 'None'} has_gate_lin={r._gate_lin is not None if r and hasattr(r, '_gate_lin') else 'N/A'}", flush=True)
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
|
||||
|
||||
Reference in New Issue
Block a user