Export NaNs in logits to scheduler_stats if output is corrupted (#18777)

Signed-off-by: Vlad Mihailescu <vtmihailescu@gmail.com>
This commit is contained in:
Vlad Tiberiu Mihailescu
2025-06-20 07:47:16 -07:00
committed by GitHub
parent 7e8977fcd4
commit 2e3e3c86dc
7 changed files with 104 additions and 2 deletions

View File

@@ -4,6 +4,7 @@
import random
import pytest
import torch
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)
def test_get_nans_in_logits(model_runner):
req_ids = ("req_0", "req_1")
scheduler_output = _schedule_new_request(*req_ids)
model_runner._update_states(scheduler_output)
logits = torch.tensor([
[1.0, 2.0, 3.0],
[3.0, 2.0, 1.0],
], device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 0, "req_1": 0}
logits = torch.tensor([
[1.0, float('nan'), 3.0],
[4.0, float('nan'), float('nan')],
],
device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 1, "req_1": 2}
logits = torch.tensor([
[1.0, 2.0, 3.0],
[4.0, float('nan'), float('nan')],
],
device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {"req_0": 0, "req_1": 2}
result = model_runner._get_nans_in_logits(logits=None)
assert result == {"req_0": 0, "req_1": 0}
logits = torch.tensor([
[1.0, float('nan'), 3.0],
], device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {'req_0': 1, 'req_1': 0}
logits = torch.tensor([
[float('nan'), float('nan'), 2.0],
[1.0, 2.0, 3.0],
[float('nan'), 2.0, 3.0],
],
device=DEVICE)
result = model_runner._get_nans_in_logits(logits)
assert result == {'req_0': 2, 'req_1': 0}
def test_update_states_no_changes(model_runner):
req_id = "req_0"