[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)

Signed-off-by: Konrad Zawora <kzawora@habana.ai>
Konrad Zawora
2025-09-11 19:15:01 +02:00
committed by GitHub
parent 1fdd5c42d7
commit 4aa23892d6
8 changed files with 53 additions and 30 deletions
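
The diff below renames VocabParallelEmbedding.forward to forward_native and adds a forward_cuda that delegates to it, so the op is picked up by CustomOp's platform-specific dispatch instead of shadowing it with a plain forward(). As a rough orientation, here is a minimal sketch of that dispatch pattern; it is not vLLM's actual CustomOp base class, and the is_cuda() probe and class names are illustrative stand-ins.

# Sketch only: assumed shape of the platform-routing pattern this commit fixes.
class CustomOp:
    def __init__(self):
        # Pick the backend-specific implementation once at construction time.
        self._forward_impl = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        # Subclasses should NOT override forward(); doing so bypasses routing.
        return self._forward_impl(*args, **kwargs)

    def dispatch_forward(self):
        # Hypothetical platform probe; a real check would query the active
        # accelerator (e.g. torch.cuda.is_available()).
        if is_cuda():
            return self.forward_cuda
        return self.forward_native

    def forward_native(self, *args, **kwargs):
        # Reference (pure-PyTorch) path; subclasses override this.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default CUDA path falls back to the reference implementation.
        return self.forward_native(*args, **kwargs)


def is_cuda() -> bool:
    # Stand-in helper for this sketch; always reports a non-CUDA platform.
    return False


class VocabParallelEmbeddingSketch(CustomOp):
    # The fix follows this shape: override forward_native (and optionally
    # forward_cuda), never forward() itself, so routing stays in effect.
    def forward_native(self, input_):
        return input_  # placeholder for the real embedding lookup

    def forward_cuda(self, input_):
        return self.forward_native(input_)


op = VocabParallelEmbeddingSketch()
op.forward("tokens")  # routed to forward_native on non-CUDA platforms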

@@ -399,7 +399,7 @@ class VocabParallelEmbedding(CustomOp):
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0]:].data.fill_(0)
 
-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
@@ -420,6 +420,9 @@ class VocabParallelEmbedding(CustomOp):
         output = tensor_model_parallel_all_reduce(output_parallel)
         return output
 
+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"