[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
@@ -399,7 +399,7 @@ class VocabParallelEmbedding(CustomOp):
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0]:].data.fill_(0)
 
-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
@@ -420,6 +420,9 @@ class VocabParallelEmbedding(CustomOp):
         output = tensor_model_parallel_all_reduce(output_parallel)
         return output
 
+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"