[V0 deprecation] Remove V0 HPU backend (#21131)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -388,20 +388,8 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         # Copy the data. Select chunk corresponding to current shard.
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-
-        if current_platform.is_hpu():
-            # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
-            # so we're using a workaround. Remove this when fixed in
-            # HPU PT bridge.
-            padded_weight = torch.cat([
-                loaded_weight,
-                torch.zeros(param.shape[0] - loaded_weight.shape[0],
-                            *loaded_weight.shape[1:])
-            ])
-            param.data.copy_(padded_weight)
-        else:
-            param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
-            param[loaded_weight.shape[0]:].data.fill_(0)
+        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
+        param[loaded_weight.shape[0]:].data.fill_(0)
 
     def forward(self, input_):
         if self.tp_size > 1:
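
For reference, below is a minimal, self-contained sketch of the copy path that the commit keeps: the loaded shard is copied into the leading rows of the (possibly padded) embedding parameter, and the remaining rows are zeroed. With the Gaudi-specific torch.cat workaround deleted, these two statements run unconditionally on all platforms. The shapes here are illustrative assumptions, not values taken from vLLM.

import torch

# Hypothetical shapes: 8 rows allocated for this shard's padded vocab,
# of which only 6 carry real embedding rows after narrow().
param = torch.empty(8, 4)
loaded_weight = torch.randn(6, 4)

# Copy the loaded shard into the leading rows and zero the padding rows,
# mirroring the two statements retained by this commit.
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0)

assert torch.equal(param[:6], loaded_weight)
assert (param[6:] == 0).all()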