Fix some typing issues found by mypy==1.18.2 (#26596)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1229,10 +1229,10 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||
if loaded_shard_id == "q":
|
||||
shard_id = self.tp_rank
|
||||
shard_rank = self.tp_rank
|
||||
else:
|
||||
shard_id = self.tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
shard_rank = self.tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_rank * shard_size
|
||||
|
||||
if not is_sharded_weight:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
|
||||
Reference in New Issue
Block a user