[platform] support custom torch.compile backend key (#11318)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
wangxiyuan
2025-01-10 23:46:51 +08:00
committed by GitHub
parent 12664ddda5
commit 20410b2fda
5 changed files with 14 additions and 5 deletions

View File

@@ -133,7 +133,7 @@ class VocabParallelEmbeddingShardIndices:
assert self.num_added_elements <= self.num_added_elements_padded
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,