[BugFix] Fix GGUF tp>1 when vocab_size is not divisible by 64 (#12230)
Signed-off-by: NickLucche <nlucches@redhat.com>
@@ -355,7 +355,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         elif isinstance(param, UninitializedParameter):
             shape = list(loaded_weight.shape)
             if output_dim is not None:
-                shape[output_dim] = shape[output_dim] // self.tp_size
+                shape[output_dim] = self.num_embeddings_per_partition
             param.materialize(tuple(shape), dtype=loaded_weight.dtype)
 
         # If parameter does not have output dim, then it should
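Why the old line fails: num_embeddings_per_partition is derived from the vocab size after padding (to a multiple of 64 by default), so dividing the raw GGUF vocab size by tp_size produces a smaller materialized tensor than the rest of the layer expects whenever the vocab is not a multiple of 64. A minimal sketch of the arithmetic, with an assumed vocab size of 32003 and tp_size of 2 (illustrative values only, not taken from the PR):

# Illustration only (assumed numbers, not vLLM code).
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Round vocab_size up to the next multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

vocab_size = 32003                                   # not divisible by 64
tp_size = 2

old_shard = vocab_size // tp_size                    # 16001 (old behaviour)
new_shard = pad_vocab_size(vocab_size) // tp_size    # 32064 // 2 = 16032

# The parameter used to be materialized with 16001 rows per rank, while the
# layer sizes everything by num_embeddings_per_partition (16032 here), so
# GGUF loading broke for tp > 1; materializing with
# num_embeddings_per_partition keeps the shapes consistent.
print(old_shard, new_shard)                          # 16001 16032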
@@ -381,7 +381,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             assert loaded_weight.shape[output_dim] == self.org_vocab_size
 
-        # Copy the data.
+        # Copy the data. Select chunk corresponding to current shard.
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         if current_platform.is_hpu():
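For context, the narrow() call above is what picks out the current rank's chunk: it returns a view of shard_size rows starting at start_idx along the vocab dimension. A toy example with assumed sizes (not vLLM code):

import torch

# Toy illustration: narrow(dim, start, length) views `shard_size` rows
# starting at `start_idx` along dim 0, i.e. the slice of the full embedding
# table owned by the current tensor-parallel rank.
full_vocab, hidden = 100, 8
weight = torch.arange(full_vocab * hidden, dtype=torch.float32).view(full_vocab, hidden)

tp_rank, shard_size = 1, 50               # rank 1 of 2, 50 vocab rows per rank
start_idx = tp_rank * shard_size
shard = weight.narrow(0, start_idx, shard_size)   # rows 50..99
assert shard.shape == (50, 8)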