Fix some typing issues found by mypy==1.18.2 (#26596)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-10 19:21:25 +01:00
committed by GitHub
parent 3b780a4bbb
commit 7c12763b24
6 changed files with 19 additions and 18 deletions

View File

@@ -1229,10 +1229,10 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
if loaded_shard_id == "q":
shard_id = self.tp_rank
shard_rank = self.tp_rank
else:
shard_id = self.tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
shard_rank = self.tp_rank // self.num_kv_head_replicas
start_idx = shard_rank * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)