Export NaNs in logits to scheduler_stats if output is corrupted (#18777)
Signed-off-by: Vlad Mihailescu <vtmihailescu@gmail.com>
This commit is contained in:
committed by
GitHub
parent
7e8977fcd4
commit
2e3e3c86dc
@@ -4,6 +4,7 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
@@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner):
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_get_nans_in_logits(model_runner):
|
||||
req_ids = ("req_0", "req_1")
|
||||
|
||||
scheduler_output = _schedule_new_request(*req_ids)
|
||||
model_runner._update_states(scheduler_output)
|
||||
|
||||
logits = torch.tensor([
|
||||
[1.0, 2.0, 3.0],
|
||||
[3.0, 2.0, 1.0],
|
||||
], device=DEVICE)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 0, "req_1": 0}
|
||||
|
||||
logits = torch.tensor([
|
||||
[1.0, float('nan'), 3.0],
|
||||
[4.0, float('nan'), float('nan')],
|
||||
],
|
||||
device=DEVICE)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 1, "req_1": 2}
|
||||
|
||||
logits = torch.tensor([
|
||||
[1.0, 2.0, 3.0],
|
||||
[4.0, float('nan'), float('nan')],
|
||||
],
|
||||
device=DEVICE)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 0, "req_1": 2}
|
||||
|
||||
result = model_runner._get_nans_in_logits(logits=None)
|
||||
assert result == {"req_0": 0, "req_1": 0}
|
||||
|
||||
logits = torch.tensor([
|
||||
[1.0, float('nan'), 3.0],
|
||||
], device=DEVICE)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {'req_0': 1, 'req_1': 0}
|
||||
|
||||
logits = torch.tensor([
|
||||
[float('nan'), float('nan'), 2.0],
|
||||
[1.0, 2.0, 3.0],
|
||||
[float('nan'), 2.0, 3.0],
|
||||
],
|
||||
device=DEVICE)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {'req_0': 2, 'req_1': 0}
|
||||
|
||||
|
||||
def test_update_states_no_changes(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user