Compare commits

..

11 Commits

Author SHA1 Message Date
roikoren755
95c0f928cd [NemotronH] Small fix reasoning parser (#36635)
Signed-off-by: Roi Koren <roik@nvidia.com>
(cherry picked from commit e661b9ee83)
2026-03-11 02:51:18 -07:00
Shaun Kotek
c9b1e977dc add nemotron v3 reasoning parser (#36393)
Signed-off-by: Shaun Kotek - Nvidia <skotek@nvidia.com>
Co-authored-by: root <root@gpu-259.slurm-workers-slurm.slurm.svc.cluster.local>
(cherry picked from commit 203a7f27da)
2026-03-11 02:51:04 -07:00
Kevin H. Luu
1ff2393897 [ci] Bound nvidia-cudnn-frontend version (#36719)
Signed-off-by: khluu <khluu000@gmail.com>
(cherry picked from commit 82b110d50e)
2026-03-10 21:20:41 -07:00
Benjamin Chislett
5bec0b0ba3 [DSV3.2][MTP] Optimize Indexer MTP handling (#36723)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
(cherry picked from commit 9040cd40af)
2026-03-10 21:20:23 -07:00
Wei Zhao
6da1310f91 [Bug] Fix TRTLLM Block FP8 MoE Monolithic (#36296)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
(cherry picked from commit 84e436ed1c)
2026-03-10 19:08:18 -07:00
khluu
bc46be5daf Revert "add nemotron v3 reasoning parser (#36393)"
This reverts commit 8e39d39fd4.
2026-03-10 11:47:09 -07:00
Shaun Kotek
8e39d39fd4 add nemotron v3 reasoning parser (#36393)
Signed-off-by: Shaun Kotek - Nvidia <skotek@nvidia.com>
Co-authored-by: root <root@gpu-259.slurm-workers-slurm.slurm.svc.cluster.local>
(cherry picked from commit 203a7f27da)
2026-03-10 09:50:38 -07:00
Vadim Gimpelson
46fa044cc1 [BUGFIX][Mamba][Qwen3.5] Zero freed SSM cache blocks on GPU (#35219)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
(cherry picked from commit 4ff8c3c8f9)
2026-03-10 09:26:18 -07:00
amirkl94
ab43e37158 Fix: Re-Enable EP for trtllm MoE FP8 backend (#36494)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
(cherry picked from commit 156e33553c)
2026-03-10 09:26:03 -07:00
Shaun Kotek
f45d010120 Fix/resupport nongated fused moe triton (#36412)
Signed-off-by: Shaun Kotek - Nvidia <skotek@nvidia.com>
Signed-off-by: Natan Bagrov <nbagrov@nvidia.com>
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: liweiguang <codingpunk@gmail.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: cong-or <conchubhar.gannon@gmail.com>
Signed-off-by: Tushar Shetty <tushar.shetty@abbyy.com>
Signed-off-by: Tushar Shetty <54362365+tusharshetty61@users.noreply.github.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: zhenwei-intel <zhenwei.liu@intel.com>
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Kevin H. Luu <khluu000@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: nvnbagrov <nbagrov@nvidia.com>
Co-authored-by: Sage <80211083+sagearc@users.noreply.github.com>
Co-authored-by: danisereb <daserebrenik@nvidia.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Weiguang Li <codingpunk@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: Alex Brooks <albrooks@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: cong-or <conchubhar.gannon@gmail.com>
Co-authored-by: Tushar Shetty <54362365+tusharshetty61@users.noreply.github.com>
Co-authored-by: liuzhenwei <zhenwei.liu@intel.com>
Co-authored-by: Xin Yang <105740670+xyang16@users.noreply.github.com>
Co-authored-by: Kevin H. Luu <khluu000@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
(cherry picked from commit fa028207aa)
2026-03-10 09:25:51 -07:00
amitz-nv
244b922088 [Bugfix] Fix passing of activation_type to trtllm fused MoE NVFP4 and FP8 (#36017)
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
(cherry picked from commit d7adcadb9b)
2026-03-10 09:25:36 -07:00
19 changed files with 515 additions and 24 deletions

View File

@@ -11,6 +11,9 @@ torchaudio==2.10.0
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.4
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
# breaking changes in 1.19.0
nvidia-cudnn-frontend>=1.13.0,<1.19.0
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl>=4.4.0.dev1

View File

@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypedDict
import pytest
import regex as re
from tests.reasoning.utils import run_reasoning_extraction
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "nemotron_v3"
class ReasoningCase(TypedDict):
output: str
reasoning: str | None
content: str | None
class FakeNemotronTokenizer:
def __init__(self):
self._vocab = {
"<think>": 1,
"</think>": 2,
}
self._pattern = re.compile(r"(<think>|</think>)")
def get_vocab(self) -> dict[str, int]:
return self._vocab
def tokenize(self, text: str) -> list[str]:
tokens: list[str] = []
for part in self._pattern.split(text):
if part:
tokens.append(part)
return tokens
def convert_tokens_to_string(self, tokens: list[str]) -> str:
return "".join(tokens)
@pytest.fixture
def tokenizer():
return FakeNemotronTokenizer()
@pytest.mark.parametrize(
"streaming,param_dict",
[
pytest.param(
False,
{
"output": "This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
},
id="without_start_token",
),
pytest.param(
True,
{
"output": "This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
},
id="without_start_token_streaming",
),
pytest.param(
False,
{
"output": "<think>This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
},
id="with_start_token",
),
pytest.param(
True,
{
"output": "<think>This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
},
id="with_start_token_streaming",
),
],
)
def test_nemotron_v3_reasoning(
tokenizer: FakeNemotronTokenizer,
streaming: bool,
param_dict: ReasoningCase,
):
output = tokenizer.tokenize(param_dict["output"])
model_output = [tokenizer.convert_tokens_to_string([token]) for token in output]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
tokenizer
)
reasoning, content = run_reasoning_extraction(
parser, model_output, streaming=streaming
)
assert reasoning == param_dict["reasoning"]
assert content == param_dict["content"]
def test_nemotron_v3_without_thinking_returns_content(
tokenizer: FakeNemotronTokenizer,
):
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(tokenizer)
request = ChatCompletionRequest(
model="test-model",
messages=[],
chat_template_kwargs={"enable_thinking": False},
)
reasoning, content = run_reasoning_extraction(
parser,
["This is plain content"],
request=request,
streaming=False,
)
assert reasoning is None
assert content == "This is plain content"
def test_nemotron_v3_force_nonempty_content_returns_content(
tokenizer: FakeNemotronTokenizer,
):
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(tokenizer)
request = ChatCompletionRequest(
model="test-model",
messages=[],
chat_template_kwargs={"force_nonempty_content": True},
)
reasoning, content = run_reasoning_extraction(
parser,
["<think>This is plain content"],
request=request,
streaming=False,
)
assert reasoning is None
assert content == "This is plain content"
def test_nemotron_v3_with_thinking_keeps_truncated_reasoning(
tokenizer: FakeNemotronTokenizer,
):
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
parser = parser_cls(tokenizer)
request = ChatCompletionRequest(
model="test-model",
messages=[],
chat_template_kwargs={"enable_thinking": True},
)
reasoning, content = run_reasoning_extraction(
parser,
["This is truncated reasoning"],
request=request,
streaming=False,
)
assert reasoning == "This is truncated reasoning"
assert content is None

View File

@@ -35,12 +35,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
):
super().__init__(moe_config, quant_config)
- if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
- raise NotImplementedError(
- "EP parallelism is not supported with TRTLLM"
- "per-tensor FP8 quantization."
- )
self.routing_method_type = moe_config.routing_method
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = (
@@ -182,9 +176,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
assert not apply_router_weight_on_input
assert activation == MoEActivation.SILU
- if e_score_correction_bias is not None:
- e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)
@@ -240,12 +231,11 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
) -> torch.Tensor:
# Delay import for non-CUDA.
import flashinfer
- from flashinfer.fused_moe.core import ActivationType
# Confirm supported activation function.
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
- activation_type = ActivationType(activation_to_flashinfer_int(activation))
+ activation_type = activation_to_flashinfer_int(activation)
# Confirm Llama-4 routing is proper.
if self.routing_method_type == RoutingMethodType.Llama4:

View File

@@ -323,4 +323,5 @@ class TrtLlmNvFp4ExpertsMonolithic(
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
do_finalize=True,
activation_type=activation_to_flashinfer_int(activation),
)[0]

View File

@@ -912,7 +912,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_no_act_and_mul() -> bool:
- return False
+ return True
@staticmethod
def _supports_quant_scheme(

View File

@@ -1944,7 +1944,7 @@ class TritonExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_no_act_and_mul() -> bool:
- return False
+ return True
@staticmethod
def _supports_quant_scheme(
@@ -1983,6 +1983,9 @@ class TritonExperts(mk.FusedMoEExpertsModular):
MoEActivation.GELU,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
@staticmethod
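The two `_supports_no_act_and_mul` flips above and the new *_NO_MUL activation entries re-enable non-gated experts in the Triton MoE paths. As a rough standalone illustration of gated versus non-gated activation (plain PyTorch, not the fused kernels; the tensor names are made up):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)  # up-projection output for one expert

# Gated "act-and-mul" form (e.g. MoEActivation.SILU): split the projection
# into a gate half and a value half, activate the gate, then multiply.
gate, value = x.chunk(2, dim=-1)
gated = F.silu(gate) * value      # shape (4, 8)

# Non-gated "*_NO_MUL" form (e.g. MoEActivation.RELU2_NO_MUL): apply the
# activation to the whole projection, with no split and no elementwise product.
nongated = torch.relu(x) ** 2     # shape (4, 16)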

View File

@@ -68,6 +68,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"mistral_reasoning_parser", "mistral_reasoning_parser",
"MistralReasoningParser", "MistralReasoningParser",
), ),
"nemotron_v3": (
"nemotron_v3_reasoning_parser",
"NemotronV3ReasoningParser",
),
"olmo3": ( "olmo3": (
"olmo3_reasoning_parser", "olmo3_reasoning_parser",
"Olmo3ReasoningParser", "Olmo3ReasoningParser",

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
class NemotronV3ReasoningParser(DeepSeekR1ReasoningParser):
"""
Reasoning parser for Nemotron V3 models.
"""
def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
) -> tuple[str | None, str | None]:
reasoning_content, final_content = super().extract_reasoning(
model_output, request
)
chat_template_kwargs = getattr(request, "chat_template_kwargs", None)
if (
chat_template_kwargs
and (
chat_template_kwargs.get("enable_thinking") is False
or chat_template_kwargs.get("force_nonempty_content") is True
)
and final_content is None
):
reasoning_content, final_content = final_content, reasoning_content
return reasoning_content, final_content
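The swap above only fires when the base DeepSeek-R1 parser produced no final content, i.e. the model stopped while still "thinking". A standalone restatement of the rule (hypothetical helper, not part of the vLLM source):

def swap_if_needed(reasoning, content, chat_template_kwargs):
    # Only when thinking is disabled (or non-empty content is forced) and the
    # base parser found no content is the reasoning text returned as content.
    if (
        chat_template_kwargs
        and (
            chat_template_kwargs.get("enable_thinking") is False
            or chat_template_kwargs.get("force_nonempty_content") is True
        )
        and content is None
    ):
        reasoning, content = content, reasoning
    return reasoning, content

assert swap_if_needed("Hello", None, {"enable_thinking": False}) == (None, "Hello")
assert swap_if_needed("Hello", None, {}) == ("Hello", None)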

View File

@@ -30,3 +30,8 @@ def round_up(x: int, y: int) -> int:
def round_down(x: int, y: int) -> int:
"""Round down x to the nearest multiple of y."""
return (x // y) * y
def largest_power_of_2_divisor(n: int) -> int:
"""Return the largest power-of-2 that divides *n* (isolate lowest set bit)."""
return n & (-n)
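For positive integers, `n & (-n)` isolates the lowest set bit of `n`, which is exactly the largest power of two dividing it. A quick standalone check of the same expression:

def lowest_set_bit(n: int) -> int:
    # In two's complement, -n flips every bit above the lowest set bit of n,
    # so the AND keeps only that bit.
    return n & (-n)

assert lowest_set_bit(12) == 4  # 12 = 0b1100 -> 0b0100
assert lowest_set_bit(8) == 8   # powers of two map to themselves
assert lowest_set_bit(7) == 1   # odd numbers -> 1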

View File

@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
) -> tuple[int, ...]:
raise NotImplementedError
@classmethod
def get_kv_cache_block_dim(
cls,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> int:
"""Discover which tensor dim is the block index, since different
backends lay out dims differently."""
_S = 1234567
shape = cls.get_kv_cache_shape(
_S,
block_size,
num_kv_heads,
head_size,
cache_dtype_str=cache_dtype_str,
)
return shape.index(_S)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
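The probe in `get_kv_cache_block_dim` works because the sentinel block count appears in exactly one position of the shape a backend reports. A standalone sketch with two hypothetical layouts (the lambdas stand in for real backends' `get_kv_cache_shape`):

def block_dim_of(shape_fn, block_size, num_kv_heads, head_size):
    # Probe the shape function with a sentinel number of blocks and report
    # which dimension the sentinel lands in.
    sentinel = 1234567
    return shape_fn(sentinel, block_size, num_kv_heads, head_size).index(sentinel)

# Blocks-outermost layout: (num_blocks, 2, block_size, heads, head_size)
assert block_dim_of(lambda n, b, h, d: (n, 2, b, h, d), 16, 8, 128) == 0
# K/V-outermost layout: (2, num_blocks, block_size, heads, head_size)
assert block_dim_of(lambda n, b, h, d: (2, n, b, h, d), 16, 8, 128) == 1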

View File

@@ -372,12 +372,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
- seq_lens - decode_lens, decode_lens
+ seq_lens - decode_lens, decode_lens, output_size=actual_expanded
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
- common_attn_metadata.query_start_loc[:num_decodes], decode_lens
+ common_attn_metadata.query_start_loc[:num_decodes],
+ decode_lens,
+ output_size=actual_expanded,
)
# [0, 1, 2, 0, 0, 1, 2, 3]
@@ -395,7 +397,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
- torch.repeat_interleave(block_table, decode_lens, dim=0)
+ torch.repeat_interleave(
+ block_table, decode_lens, dim=0, output_size=actual_expanded
+ )
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
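These call sites now pass `output_size` to `torch.repeat_interleave`; without it, torch has to sum the repeats tensor first to size the output, which forces a device-to-host sync when the repeats live on the GPU. A small CPU sketch of the equivalence:

import torch

values = torch.tensor([7, 6, 8])   # e.g. seq_lens - decode_lens
repeats = torch.tensor([3, 1, 4])  # e.g. decode_lens

baseline = torch.repeat_interleave(values, repeats)
# With the total known ahead of time (actual_expanded in the builder), the
# kernel can launch without reducing `repeats` on the device first.
hinted = torch.repeat_interleave(values, repeats, output_size=8)

assert torch.equal(baseline, hinted)  # tensor([7, 7, 7, 6, 8, 8, 8, 8])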

View File

@@ -489,6 +489,13 @@ class KVCacheManager:
# Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
def take_new_block_ids(self) -> list[int]:
"""Drain and return new attention block IDs for zeroing."""
ids: list[int] = []
for mgr in self.coordinator.single_type_managers:
ids.extend(mgr.take_new_block_ids())
return ids
def new_step_starts(self) -> None:
"""Called when a new step is started."""
self.coordinator.new_step_starts()

View File

@@ -233,6 +233,11 @@ class SchedulerOutput:
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
# Block IDs freshly allocated from the pool during this scheduling step.
# The worker zeros the corresponding GPU memory before the blocks are used,
# preventing stale NaN/data from corrupting attention or SSM computation.
new_block_ids_to_zero: list[int] | None = None
@classmethod
def make_empty(cls) -> "SchedulerOutput":
return cls(

View File

@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import (
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
- from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
+ from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface):
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
- def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
- return any(
- isinstance(group_spec.kv_cache_spec, MambaSpec)
- for group_spec in kv_cache_config.kv_cache_groups
- )
- self.has_mamba_layers = has_mamba_layers(kv_cache_config)
+ self.has_mamba_layers = kv_cache_config.has_mamba_layers
+ self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
self.need_mamba_block_aligned_split = (
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
)
@@ -871,6 +866,12 @@ class Scheduler(SchedulerInterface):
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
new_block_ids_to_zero = (
(self.kv_cache_manager.take_new_block_ids() or None)
if self.needs_kv_cache_zeroing
else None
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
@@ -886,6 +887,7 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
)
# NOTE(Kuntai): this function is designed for multiple purposes:

View File

@@ -55,6 +55,7 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
self.enable_caching = enable_caching
self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
@@ -208,6 +209,8 @@ class SingleTypeKVCacheManager(ABC):
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
)
req_blocks.extend(allocated_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_tokens_main_model: int
@@ -234,8 +237,16 @@ class SingleTypeKVCacheManager(ABC):
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in new_blocks)
return new_blocks
def take_new_block_ids(self) -> list[int]:
"""Drain and return block IDs allocated since the last call."""
ids = self.new_block_ids
self.new_block_ids = []
return ids
def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.

View File

@@ -489,3 +489,11 @@ class KVCacheConfig:
For models with multiple types of attention, there will be multiple groups,
see `_get_kv_cache_config_uniform_page_size` for more details.
"""
@property
def has_mamba_layers(self) -> bool:
return any(isinstance(g.kv_cache_spec, MambaSpec) for g in self.kv_cache_groups)
@property
def needs_kv_cache_zeroing(self) -> bool:
return self.has_mamba_layers

View File

@@ -187,6 +187,7 @@ from vllm.v1.worker.workspace import lock_workspace
from .utils import (
AttentionGroup,
KVBlockZeroer,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
prepare_kernel_block_sizes,
@@ -918,6 +919,26 @@ class GPUModelRunner(
decode_threshold=self.reorder_batch_threshold,
)
def _init_kv_zero_meta(self) -> None:
"""One-time precomputation for _zero_block_ids.
Delegates to KVBlockZeroer.init_meta with the runner's state.
Called from gpu_worker.py outside the CuMem pool context.
"""
self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory)
self._kv_block_zeroer.init_meta(
attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
kernel_block_sizes=self._kernel_block_sizes,
cache_dtype=self.cache_config.cache_dtype,
runner_only_attn_layers=self.runner_only_attn_layers,
static_forward_context=(self.compilation_config.static_forward_context),
)
def _zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if hasattr(self, "_kv_block_zeroer"):
self._kv_block_zeroer.zero_block_ids(block_ids)
# Note: used for model runner override.
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties"""
@@ -951,6 +972,11 @@ class GPUModelRunner(
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)
# Zero GPU memory for freshly allocated cache blocks to prevent
# stale NaN/data from corrupting attention or SSM computation.
if scheduler_output.new_block_ids_to_zero:
self._zero_block_ids(scheduler_output.new_block_ids_to_zero)
# Free the cached encoder outputs.
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
@@ -6066,6 +6092,7 @@ class GPUModelRunner(
kernel_block_sizes = prepare_kernel_block_sizes(
kv_cache_config, self.attn_groups
)
self._kernel_block_sizes = kernel_block_sizes
# create metadata builders
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)

View File

@@ -480,6 +480,14 @@ class Worker(WorkerBase):
else:
self.model_runner.initialize_kv_cache(kv_cache_config)
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
self.model_runner, "_init_kv_zero_meta"
):
self.model_runner._init_kv_zero_meta()
@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> float:
warmup_sizes = []

View File

@@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import product as iprod
from typing import Any
import torch
@@ -12,6 +15,8 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import largest_power_of_2_divisor
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -21,6 +26,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.kv_cache_interface import (
AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
@@ -31,6 +37,186 @@ from vllm.v1.kv_cache_interface import (
logger = init_logger(__name__)
@triton.jit
def _zero_kv_blocks_kernel(
seg_addrs_ptr,
block_ids_ptr,
n_blocks,
N_SEGS: tl.constexpr,
PAGE_SIZE_EL: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Zero KV cache blocks across all segments in a single launch.
Each segment is a contiguous region of one block's data. For backends
where blocks are outermost (block_dim=0) there is one segment per
buffer. For backends where K/V is outermost (block_dim=1) there are
two segments per buffer (one for K, one for V).
seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
allowing segments to live in different CUDA allocations.
Programs are mapped as (block_index, seg_index, chunk_index).
"""
pid = tl.program_id(0)
chunks = PAGE_SIZE_EL // BLOCK_SIZE
work_per_block = N_SEGS * chunks
block_index = pid // work_per_block
if block_index >= n_blocks:
return
remainder = pid % work_per_block
seg_index = remainder // chunks
chunk_index = remainder % chunks
block_id = tl.load(block_ids_ptr + block_index)
seg_addr = tl.load(seg_addrs_ptr + seg_index)
ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
offset = (
block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
)
cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
class KVBlockZeroer:
"""Manages efficient zeroing of KV cache blocks via a Triton kernel.
Call :meth:`init_meta` once after KV caches are allocated to precompute
segment addresses, then call :meth:`zero_block_ids` each step to zero
newly-allocated blocks.
"""
def __init__(self, device: torch.device, pin_memory: bool):
self.device = device
self.pin_memory = pin_memory
self._meta: tuple[torch.Tensor, int, int, int] | None = None
self._id_cap: int = 0
self._ids_pinned: torch.Tensor | None = None
self._ids_gpu: torch.Tensor | None = None
def init_meta(
self,
attn_groups_iter: Iterable["AttentionGroup"],
kernel_block_sizes: list[int],
cache_dtype: str,
runner_only_attn_layers: set[str],
static_forward_context: dict[str, Any],
) -> None:
"""One-time precomputation for zero_block_ids.
Builds absolute-address table for the Triton zeroing kernel.
Each entry is the absolute byte address of a segment start on the
GPU, so segments in different CUDA allocations work correctly.
Block IDs from the scheduler reference logical blocks whose size
may differ from the kernel block size (virtual block splitting).
PAGE_SIZE_EL accounts for this ratio so that
``block_id * PAGE_SIZE_EL`` lands at the correct offset.
Only AttentionSpec layers are processed; Mamba layers are skipped.
"""
seen_ptrs: set[int] = set()
seg_addrs: list[int] = []
page_size_el: int | None = None
for group in attn_groups_iter:
spec = group.kv_cache_spec
if type(spec) is not FullAttentionSpec:
continue
if group.kv_cache_group_id >= len(kernel_block_sizes):
continue
kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
ratio = spec.block_size // kernel_bs
block_dim = group.backend.get_kv_cache_block_dim(
kernel_bs,
spec.num_kv_heads,
spec.head_size,
cache_dtype_str=cache_dtype,
)
for layer_name in group.layer_names:
if layer_name in runner_only_attn_layers:
continue
kv = static_forward_context[layer_name].kv_cache[0]
if isinstance(kv, list):
continue
dp = kv.data_ptr()
if dp in seen_ptrs:
continue
seen_ptrs.add(dp)
el = kv.element_size()
cur_bytes = kv.stride(block_dim) * el
assert cur_bytes % 4 == 0
kernel_block_el = cur_bytes // 4
cur_page_el = kernel_block_el * ratio
if page_size_el is None:
page_size_el = cur_page_el
else:
assert page_size_el == cur_page_el, (
f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
)
block_stride_bytes = cur_bytes
outer_dims = [
d
for d in range(block_dim)
if kv.stride(d) * el > block_stride_bytes
]
outer_strides = [kv.stride(d) * el for d in outer_dims]
for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
seg_addrs.append(dp + off_bytes)
if not seg_addrs or page_size_el is None:
self._meta = None
return
blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
self._id_cap = 8192
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(self._id_cap, dtype=torch.int64, device=self.device)
self._meta = (
torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
page_size_el,
blk_size,
len(seg_addrs),
)
def zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if not block_ids or self._meta is None:
return
seg_addrs, page_size_el, blk_size, n_segs = self._meta
n_blocks = len(block_ids)
if n_blocks > self._id_cap:
self._id_cap = n_blocks * 2
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(
self._id_cap, dtype=torch.int64, device=self.device
)
assert self._ids_pinned is not None and self._ids_gpu is not None
self._ids_pinned[:n_blocks].numpy()[:] = block_ids
idx = self._ids_gpu[:n_blocks]
idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
grid = (n_blocks * n_segs * (page_size_el // blk_size),)
_zero_kv_blocks_kernel[grid](
seg_addrs,
idx,
n_blocks,
N_SEGS=n_segs,
PAGE_SIZE_EL=page_size_el,
BLOCK_SIZE=blk_size,
)
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
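The flat program id in `_zero_kv_blocks_kernel` is decomposed block-first, then segment, then chunk, exactly like unflattening a 3-D index. The same arithmetic in plain Python (purely illustrative, no Triton required):

def decode_pid(pid: int, n_segs: int, chunks: int) -> tuple[int, int, int]:
    # chunks corresponds to PAGE_SIZE_EL // BLOCK_SIZE in the kernel.
    work_per_block = n_segs * chunks
    block_index = pid // work_per_block
    remainder = pid % work_per_block
    return block_index, remainder // chunks, remainder % chunks

# With 2 segments (e.g. separate K and V regions) and 4 chunks per page,
# program 13 zeroes chunk 1 of segment 1 of the second block in the launch:
assert decode_pid(13, n_segs=2, chunks=4) == (1, 1, 1)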