[platform] support custom torch.compile backend key (#11318)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -133,7 +133,7 @@ class VocabParallelEmbeddingShardIndices:
|
||||
assert self.num_added_elements <= self.num_added_elements_padded
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def get_masked_input_and_mask(
|
||||
input_: torch.Tensor, org_vocab_start_index: int,
|
||||
org_vocab_end_index: int, num_org_vocab_padding: int,
|
||||
|
||||
Reference in New Issue
Block a user