[CI] Bump mypy version (#34950)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -972,6 +972,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         # Early-out for cascade attention
         if use_cascade:
+            assert num_blocks_np is not None
            # Grab the blocks of the shared prefix from the first request.
            num_common_kv_blocks = common_prefix_len // page_size
 
@@ -1117,6 +1118,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 max_seq_len=max_seq_len,
             )
         else:
+            assert seq_lens_cpu is not None
            pure_decode = num_prefills == 0
            use_cudagraph = (
                self.enable_cuda_graph
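
Note: both added asserts follow the usual pattern for satisfying a stricter mypy when a value is declared Optional: the checker only accepts using the value after an `assert x is not None` has narrowed its type. A minimal, self-contained sketch of the idea (the function and names below are illustrative, not code from this diff):

    from typing import Optional

    import numpy as np


    def shared_prefix_blocks(
        num_blocks_np: Optional[np.ndarray], common_prefix_len: int, page_size: int
    ) -> int:
        # Without this assert, mypy reports that `num_blocks_np` may be None at
        # the indexing below; the assert narrows its type to `np.ndarray`.
        assert num_blocks_np is not None
        num_common_kv_blocks = common_prefix_len // page_size
        return min(num_common_kv_blocks, int(num_blocks_np[0]))

The runtime behaviour is unchanged for valid inputs; the assert only makes the non-None assumption explicit to the type checker.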
@@ -88,14 +88,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             self.num_spec: int = self.speculative_config.num_speculative_tokens
         else:
             self.num_spec = 0
-        self.use_spec_decode = self.num_spec > 0
+        self.use_spec_decode: bool = self.num_spec > 0
         self._init_reorder_batch_threshold(1, self.use_spec_decode)
 
-        self.use_full_cuda_graph = (
+        self.use_full_cuda_graph: bool = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
 
-        self.decode_cudagraph_max_bs = (
+        self.decode_cudagraph_max_bs: int = (
             self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
         )
         if self.compilation_config.max_cudagraph_capture_size is not None:
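
Note: the replaced lines in this hunk all add explicit attribute annotations in `__init__`. A minimal sketch of the pattern, using illustrative names rather than the real builder:

    class ExampleBuilder:
        def __init__(self, num_spec: int, max_num_seqs: int) -> None:
            self.num_spec: int = num_spec
            # Annotating the attribute states the intended type directly instead
            # of leaving mypy to infer it from the right-hand side expression.
            self.use_spec_decode: bool = self.num_spec > 0
            self.decode_cudagraph_max_bs: int = max_num_seqs * (self.num_spec + 1)

Runtime behaviour is unchanged; only the declared attribute types become explicit.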
@@ -104,42 +104,42 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 self.compilation_config.max_cudagraph_capture_size,
             )
 
-        self.spec_state_indices_tensor = torch.empty(
+        self.spec_state_indices_tensor: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs, self.num_spec + 1),
             dtype=torch.int32,
             device=device,
         )
-        self.non_spec_state_indices_tensor = torch.empty(
+        self.non_spec_state_indices_tensor: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs,),
             dtype=torch.int32,
             device=device,
         )
-        self.spec_sequence_masks = torch.empty(
+        self.spec_sequence_masks: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs,),
             dtype=torch.bool,
             device=device,
         )
-        self.spec_token_indx = torch.empty(
+        self.spec_token_indx: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
             dtype=torch.int32,
             device=device,
         )
-        self.non_spec_token_indx = torch.empty(
+        self.non_spec_token_indx: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
             dtype=torch.int32,
             device=device,
         )
-        self.spec_query_start_loc = torch.empty(
+        self.spec_query_start_loc: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs + 1,),
             dtype=torch.int32,
             device=device,
         )
-        self.non_spec_query_start_loc = torch.empty(
+        self.non_spec_query_start_loc: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs + 1,),
             dtype=torch.int32,
             device=device,
         )
-        self.num_accepted_tokens = torch.empty(
+        self.num_accepted_tokens: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs,),
             dtype=torch.int32,
             device=device,
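
Note: the same one-line change repeats for each preallocated CUDA-graph buffer: the attribute gains an explicit `torch.Tensor` annotation while the `torch.empty(...)` call is untouched. A short illustrative sketch (buffer names and sizes are placeholders, not the real builder's fields):

    import torch


    class DecodeBuffers:
        def __init__(self, max_bs: int, num_spec: int, device: torch.device) -> None:
            # Explicit annotations document the buffer type and let mypy check any
            # later assignment to these attributes against `torch.Tensor`.
            self.spec_sequence_masks: torch.Tensor = torch.empty(
                (max_bs,), dtype=torch.bool, device=device
            )
            self.spec_token_indx: torch.Tensor = torch.empty(
                (max_bs * (num_spec + 1),), dtype=torch.int32, device=device
            )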
@@ -322,6 +322,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             and num_spec_decodes <= self.decode_cudagraph_max_bs
             and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
         ):
+            assert spec_sequence_masks is not None
            self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                spec_state_indices_tensor, non_blocking=True
            )
@@ -98,8 +98,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         self.use_spec_decode = self.num_spec_tokens > 0
 
         assert isinstance(kv_cache_spec, MambaSpec)
         self.compilation_config = vllm_config.compilation_config
-        self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
+        scheduler_config = vllm_config.scheduler_config
+        self.decode_cudagraph_max_bs: int = scheduler_config.max_num_seqs
         if self.compilation_config.max_cudagraph_capture_size is not None:
             self.decode_cudagraph_max_bs = min(
                 self.decode_cudagraph_max_bs,
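
Note: `assert isinstance(kv_cache_spec, MambaSpec)` is the subclass counterpart of the `is not None` asserts above: mypy narrows the declared base type to the concrete class, so attributes specific to that class type-check afterwards. A small sketch with placeholder classes (not the real spec hierarchy):

    class KVCacheSpec:
        pass


    class MambaSpec(KVCacheSpec):
        def __init__(self, max_num_blocks: int) -> None:
            self.max_num_blocks = max_num_blocks


    def max_blocks(kv_cache_spec: KVCacheSpec) -> int:
        # `kv_cache_spec` is only known as the base class here; the isinstance
        # assert narrows it to MambaSpec, making `.max_num_blocks` visible to mypy.
        assert isinstance(kv_cache_spec, MambaSpec)
        return kv_cache_spec.max_num_blocks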
@@ -114,7 +114,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         # Speculative decoding not supported with prefix caching,
         # so keep shape consistent with prefill buffer
         # TODO: reduce this size as needed for decode-only cudagraph capture
-        self.state_indices_tensor_d = torch.empty(
+        self.state_indices_tensor_d: torch.Tensor = torch.empty(
             (
                 self.decode_cudagraph_max_bs,
                 max_num_blocks,
@@ -122,12 +122,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
             dtype=torch.int32,
             device=device,
         )
-        self.block_idx_last_scheduled_token = torch.empty(
+        self.block_idx_last_scheduled_token: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs,),
             dtype=torch.int32,
             device=device,
         )
-        self.block_idx_last_computed_token = torch.empty(
+        self.block_idx_last_computed_token: torch.Tensor = torch.empty(
             (self.decode_cudagraph_max_bs,),
             dtype=torch.int32,
             device=device,
@@ -142,7 +142,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         # For speculative decoding, we need to store the following buffers
         # for CUDA graph capture during decode
         if self.num_spec_tokens > 0:
-            self.decode_num_accepted_tokens = torch.empty(
+            self.decode_num_accepted_tokens: torch.Tensor = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,