diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 85d0744db..a480eeff0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,7 +55,7 @@ repos: language: python types_or: [python, pyi] require_serial: true - additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] + additional_dependencies: ["mypy[faster-cache]==1.15.0", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 entry: python tools/pre_commit/mypy.py 1 "3.10" diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index b43e1dab4..5094a29c5 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -94,12 +94,9 @@ def test_rotary_embedding( positions = torch.randint(0, max_position, (batch_size, seq_len)) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) - query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) if use_key else None - # slice tensor if required, noop otherwise - query = query[..., :head_size] - key = key[..., :head_size] if use_key else None + query = torch.randn(query_shape, dtype=dtype)[..., :head_size] + key = torch.randn_like(query)[..., :head_size] if use_key else None # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index 912a422e0..6cdd94fdc 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -62,7 +62,7 @@ def test_rotary_embedding_opcheck( ) key = torch.randn_like(query) if use_key else None query = query[..., :head_size] - key = key[..., :head_size] if use_key else None + key = key[..., :head_size] if key is not None else None rotary_embedding_opcheck(rot, positions, query, key) @@ -73,5 +73,5 @@ def test_rotary_embedding_opcheck( rot, positions, query.flatten(start_dim=-2), - key.flatten(start_dim=-2) if use_key else None, + key.flatten(start_dim=-2) if key is not None else None, ) diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 9a00e1d04..e8cbba29f 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -298,13 +298,13 @@ def test_selective_scan( C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None - D_ref = D.clone() + D_ref = D.clone() if D is not None else None z = ( torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) if has_z else None ) - z_ref = z.clone() if has_z else None + z_ref = z.clone() if z is not None else None delta_bias = ( (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) if has_delta_bias @@ -493,7 +493,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len): B[idx : idx + 1], C[idx : idx + 1], D=D, - z=z[idx : idx + 1] if has_z else None, + z=z[idx : idx + 1] if z is not None else None, dt_bias=dt_bias, dt_softplus=True, ) @@ -578,7 +578,7 @@ def test_selective_scan_varlen( C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None - D_ref = D.clone() + D_ref = D.clone() if D is not None else None z = torch.randn(dim, seqlen, device=device, dtype=itype) z_ref = z.clone() delta_bias = ( @@ -750,7 +750,7 @@ def test_selective_state_update_with_batch_indices( B[:batch_size], C[:batch_size], D=D, - z=z[:batch_size], + z=z[:batch_size] if z is not None else None, dt_bias=dt_bias, dt_softplus=True, ) @@ -934,7 +934,7 @@ def test_selective_state_update_with_num_accepted_tokens( B[global_idx : global_idx + 1], C[global_idx : global_idx + 1], D=D, - z=z[global_idx : global_idx + 1] if has_z else None, + z=z[global_idx : global_idx + 1] if z is not None else None, dt_bias=dt_bias, dt_softplus=True, ) @@ -1061,7 +1061,7 @@ def test_selective_state_update_varlen_with_num_accepted( B[global_idx : global_idx + 1], C[global_idx : global_idx + 1], D=D, - z=z[global_idx : global_idx + 1] if has_z else None, + z=z[global_idx : global_idx + 1] if z is not None else None, dt_bias=dt_bias, dt_softplus=True, ) diff --git a/tests/kernels/quantization/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py index ce94d3397..cec6d37e1 100644 --- a/tests/kernels/quantization/test_fp8_quant.py +++ b/tests/kernels/quantization/test_fp8_quant.py @@ -57,11 +57,11 @@ def opcheck_fp8_quant( @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("scale_ub", SCALE_UBS) +@pytest.mark.parametrize("do_scale_ub", SCALE_UBS) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_dynamic_per_token_fp8_quant( - num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int + num_tokens: int, hidden_size: int, dtype: torch.dtype, do_scale_ub: bool, seed: int ) -> None: set_random_seed(seed) @@ -70,7 +70,7 @@ def test_dynamic_per_token_fp8_quant( ) # avoid nans scale_ub = ( - torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None + torch.mean(x).to(dtype=torch.float32, device="cuda") if do_scale_ub else None ) ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub) ops_out, ops_scales = ops.scaled_fp8_quant( diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 6b69198eb..8ec6af2aa 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -3,11 +3,11 @@ import os from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, overload import torch from pydantic import Field, field_validator, model_validator -from torch.distributed import ProcessGroup, ReduceOp +from torch.distributed import ProcessGroup, ReduceOp, Store from typing_extensions import Self import vllm.envs as envs @@ -507,7 +507,17 @@ class ParallelConfig: def get_next_stateless_eplb_group_port(self) -> list[int]: return self._stateless_eplb_group_port_list.pop() - def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup: + @overload + def stateless_init_dp_group( + self, return_store: Literal[False] = ... + ) -> ProcessGroup: ... + @overload + def stateless_init_dp_group( + self, return_store: Literal[True] = ... + ) -> tuple[ProcessGroup, Store]: ... + def stateless_init_dp_group( + self, return_store: bool = False + ) -> ProcessGroup | tuple[ProcessGroup, Store]: # NOTE: In high-concurrency scenarios multiple processes # can pick the same (currently free) port through a race # condition when calling `get_open_port()`. When the first diff --git a/vllm/distributed/elastic_ep/elastic_state.py b/vllm/distributed/elastic_ep/elastic_state.py index 4845a16f1..fce0d8361 100644 --- a/vllm/distributed/elastic_ep/elastic_state.py +++ b/vllm/distributed/elastic_ep/elastic_state.py @@ -4,7 +4,7 @@ import enum import time import weakref from datetime import timedelta -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, TypeAlias import torch.distributed @@ -61,6 +61,14 @@ class ScaleDownRemovingEngineState(enum.IntEnum): COMPLETE = 2 +EngineState: TypeAlias = ( + ScaleUpExistingEngineState + | ScaleUpNewEngineState + | ScaleDownRemainingEngineState + | ScaleDownRemovingEngineState +) + + class _BarrierTimeoutError(RuntimeError): """ Exception raised for timeout @@ -87,14 +95,13 @@ class ElasticEPScalingState: self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None self.new_parallel_config: ParallelConfig = new_parallel_config - self.new_dp_group: torch.distributed.ProcessGroup | None = ( - self.engine_core.dp_group if worker_type == "new" else None - ) + self.new_dp_group = self.engine_core.dp_group if worker_type == "new" else None self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None self.worker_type = worker_type self.scale_type = scale_type self.reconfig_request = reconfig_request + self.state: EngineState if scale_type == "scale_up": self.state = ( ScaleUpNewEngineState.PREPARE @@ -182,9 +189,9 @@ class ElasticEPScalingState: engine step, and will synchronize with the other EngineCores in the next step with a barrier without timeout. """ - dp_store = self.new_dp_store if use_new_group else self.old_dp_store dp_group = self.new_dp_group if use_new_group else self.old_dp_group - assert dp_group is not None + dp_store = self.new_dp_store if use_new_group else self.old_dp_store + assert dp_group is not None and dp_store is not None group_rank = dp_group.rank() group_size = dp_group.size() @@ -212,6 +219,7 @@ class ElasticEPScalingState: def _progress_existing_engine(self) -> bool: state = self.state + assert self.old_dp_group is not None and self.old_dp_store is not None if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT: return False @@ -265,11 +273,12 @@ class ElasticEPScalingState: elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE: self._switch_and_prepare() self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE + assert self.new_dp_store is not None self.new_dp_store.add("eep_barrier_engine_count", 1) return True elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE: - assert self.new_dp_group is not None + assert self.new_dp_group is not None and self.new_dp_store is not None if ( int(self.new_dp_store.get("eep_barrier_engine_count")) < self.new_dp_group.size() @@ -292,7 +301,7 @@ class ElasticEPScalingState: def _progress_new_engine(self) -> bool: state = self.state - assert self.new_dp_group is not None + assert self.new_dp_group is not None and self.new_dp_store is not None if state == ScaleUpNewEngineState.PREPARE: tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu") @@ -330,6 +339,7 @@ class ElasticEPScalingState: def _progress_remaining_engine(self) -> bool: state = self.state + assert self.old_dp_group is not None and self.old_dp_store is not None if state == ScaleDownRemainingEngineState.PREPARE: self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE @@ -369,6 +379,7 @@ class ElasticEPScalingState: def _progress_removing_engine(self) -> bool: state = self.state + assert self.old_dp_group is not None and self.old_dp_store is not None if state == ScaleDownRemovingEngineState.PREPARE: self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE @@ -401,6 +412,7 @@ class ElasticEPScalingState: def handle_notification(self, notification_type: EEPNotificationType): assert self.worker_type != "new" + assert self.old_dp_store is not None if ( notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT @@ -429,6 +441,7 @@ class ElasticEPScalingState: ) def _create_standby_groups(self): + assert self.old_dp_group is not None self.new_dp_group, self.new_dp_store = ( self.new_parallel_config.stateless_init_dp_group(return_store=True) ) @@ -439,7 +452,7 @@ class ElasticEPScalingState: logger.info("[Elastic EP] Created standby communication groups") def _transfer_weights(self): - assert self.reconfig_request is not None + assert self.reconfig_request is not None and self.old_dp_group is not None old_dp_size = self.old_dp_group.size() new_dp_size = self.reconfig_request.new_data_parallel_size @@ -450,6 +463,7 @@ class ElasticEPScalingState: logger.info("[Elastic EP] Transferred weights to new workers") def _transfer_expert_mapping(self): + assert self.old_dp_group is not None self.model_executor.collective_rpc( "elastic_ep_execute", args=("broadcast_expert_mapping",) ) @@ -458,7 +472,7 @@ class ElasticEPScalingState: def _sync_kv_cache_memory_size(self): assert self.engine_core.available_gpu_memory_for_kv_cache > 0 - assert self.new_dp_group is not None + assert self.new_dp_group is not None and self.old_dp_group is not None ParallelConfig.sync_kv_cache_memory_size( self.new_dp_group, self.engine_core.available_gpu_memory_for_kv_cache, @@ -507,7 +521,7 @@ class ElasticEPScalingState: logger.info("[Elastic EP] EPLB reshuffle completed") def _eplb_reshuffle_before_scale_down(self): - assert self.reconfig_request is not None + assert self.reconfig_request is not None and self.old_dp_group is not None self.model_executor.collective_rpc( "elastic_ep_execute", args=( diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f9367da73..fb6bbf7b5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -336,6 +336,7 @@ class TpKVTopology: self._cross_layers_blocks = ( len(self.tensor_shape) == len(kv_cache_shape) + 1 ) + self.tensor_shape: torch.Size if self._cross_layers_blocks: logger.debug("Using cross-layer KV cache") diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 46e9d2cb5..091a98952 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -972,6 +972,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # Early-out for cascade attention if use_cascade: + assert num_blocks_np is not None # Grab the blocks of the shared prefix from the first request. num_common_kv_blocks = common_prefix_len // page_size @@ -1117,6 +1118,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): max_seq_len=max_seq_len, ) else: + assert seq_lens_cpu is not None pure_decode = num_prefills == 0 use_cudagraph = ( self.enable_cuda_graph diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 3f76f3e24..a2dd05b4b 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -88,14 +88,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] self.num_spec: int = self.speculative_config.num_speculative_tokens else: self.num_spec = 0 - self.use_spec_decode = self.num_spec > 0 + self.use_spec_decode: bool = self.num_spec > 0 self._init_reorder_batch_threshold(1, self.use_spec_decode) - self.use_full_cuda_graph = ( + self.use_full_cuda_graph: bool = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) - self.decode_cudagraph_max_bs = ( + self.decode_cudagraph_max_bs: int = ( self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1) ) if self.compilation_config.max_cudagraph_capture_size is not None: @@ -104,42 +104,42 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] self.compilation_config.max_cudagraph_capture_size, ) - self.spec_state_indices_tensor = torch.empty( + self.spec_state_indices_tensor: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs, self.num_spec + 1), dtype=torch.int32, device=device, ) - self.non_spec_state_indices_tensor = torch.empty( + self.non_spec_state_indices_tensor: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) - self.spec_sequence_masks = torch.empty( + self.spec_sequence_masks: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.bool, device=device, ) - self.spec_token_indx = torch.empty( + self.spec_token_indx: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs * (self.num_spec + 1),), dtype=torch.int32, device=device, ) - self.non_spec_token_indx = torch.empty( + self.non_spec_token_indx: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs * (self.num_spec + 1),), dtype=torch.int32, device=device, ) - self.spec_query_start_loc = torch.empty( + self.spec_query_start_loc: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs + 1,), dtype=torch.int32, device=device, ) - self.non_spec_query_start_loc = torch.empty( + self.non_spec_query_start_loc: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs + 1,), dtype=torch.int32, device=device, ) - self.num_accepted_tokens = torch.empty( + self.num_accepted_tokens: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, @@ -322,6 +322,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] and num_spec_decodes <= self.decode_cudagraph_max_bs and num_spec_decode_tokens <= self.decode_cudagraph_max_bs ): + assert spec_sequence_masks is not None self.spec_state_indices_tensor[:num_spec_decodes].copy_( spec_state_indices_tensor, non_blocking=True ) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 27c9b85eb..f9105474e 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -98,8 +98,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): self.use_spec_decode = self.num_spec_tokens > 0 assert isinstance(kv_cache_spec, MambaSpec) - self.compilation_config = vllm_config.compilation_config - self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs + scheduler_config = vllm_config.scheduler_config + self.decode_cudagraph_max_bs: int = scheduler_config.max_num_seqs if self.compilation_config.max_cudagraph_capture_size is not None: self.decode_cudagraph_max_bs = min( self.decode_cudagraph_max_bs, @@ -114,7 +114,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): # Speculative decoding not supported with prefix caching, # so keep shape consistent with prefill buffer # TODO: reduce this size as needed for decode-only cudagraph capture - self.state_indices_tensor_d = torch.empty( + self.state_indices_tensor_d: torch.Tensor = torch.empty( ( self.decode_cudagraph_max_bs, max_num_blocks, @@ -122,12 +122,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): dtype=torch.int32, device=device, ) - self.block_idx_last_scheduled_token = torch.empty( + self.block_idx_last_scheduled_token: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) - self.block_idx_last_computed_token = torch.empty( + self.block_idx_last_computed_token: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, @@ -142,7 +142,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): # For speculative decoding, we need to store the following buffers # for CUDA graph capture during decode if self.num_spec_tokens > 0: - self.decode_num_accepted_tokens = torch.empty( + self.decode_num_accepted_tokens: torch.Tensor = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e63c55427..9b70e4a9c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1539,18 +1539,18 @@ class DPEngineCoreProc(EngineCoreProc): def _init_data_parallel(self, vllm_config: VllmConfig): # Configure GPUs and stateless process group for data parallel. - dp_rank = vllm_config.parallel_config.data_parallel_rank - dp_size = vllm_config.parallel_config.data_parallel_size - local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + parallel_config = vllm_config.parallel_config + dp_rank = parallel_config.data_parallel_rank + dp_size = parallel_config.data_parallel_size + local_dp_rank = parallel_config.data_parallel_rank_local assert dp_size > 1 assert local_dp_rank is not None assert 0 <= local_dp_rank <= dp_rank < dp_size self.dp_rank = dp_rank - self.dp_group, self.dp_store = ( - vllm_config.parallel_config.stateless_init_dp_group(return_store=True) - ) + dp_group, dp_store = parallel_config.stateless_init_dp_group(return_store=True) + self.dp_group, self.dp_store = dp_group, dp_store def shutdown(self): super().shutdown() diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 693f7b125..2cb89e1ea 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -309,12 +309,16 @@ class AdapterLogitsProcessor(LogitsProcessor): """ if req_lp := self.new_req_logits_processor(params): - args = ( - [prompt_ids, output_ids] - if (len(inspect.signature(req_lp).parameters) == 3) - else [output_ids] - ) - return partial(req_lp, *args) # type: ignore[misc] + if len(inspect.signature(req_lp).parameters) == 3: + if prompt_ids is None: + raise ValueError( + "Prompt token ids are required for this " + "logits processor but were not provided." + ) + args = [prompt_ids, output_ids] + else: + args = [output_ids] + return partial(req_lp, *args) return None def update_state(self, batch_update: BatchUpdate | None):