[Speculators][Speculative Decoding] Fix gpt-oss eagle3 accuracy issue (#25406)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
jiahanc
2025-09-23 12:44:35 -07:00
committed by GitHub
parent 24fab45d96
commit d5944d5146
6 changed files with 79 additions and 17 deletions

View File

@@ -534,6 +534,8 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
@@ -660,6 +662,8 @@ def test_propose_tree(spec_token_tree):
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)
# Setup inputs for the proposer.
target_token_ids = torch.randint(0,