Compare commits
4 Commits
v0.15.1rc1
...
v0.11.2
| Author | SHA1 | Date | |
|---|---|---|---|
| | 275de34170 | | |
| | fa3ffb4365 | | |
| | 6d5974369c | | |
| | 0ce9990d2c | | |

@@ -512,9 +512,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
@@ -619,9 +619,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# FP4 Archs and flags
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
@@ -695,7 +695,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
|
||||
@@ -741,9 +741,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu")
|
||||
|
||||
@@ -917,7 +917,7 @@ class CompilationConfig:
|
||||
self, uniform_decode_query_len: int, tensor_parallel_size: int
|
||||
):
|
||||
multiple_of = uniform_decode_query_len
|
||||
if tensor_parallel_size > 1:
|
||||
if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
|
||||
multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
|
||||
if (
|
||||
multiple_of % uniform_decode_query_len != 0
|
||||
|
||||
@@ -755,6 +755,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
|
||||
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
|
||||
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
@@ -944,18 +945,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
dcp_tot_seq_lens_device = None
|
||||
if self.dcp_world_size > 1:
|
||||
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
|
||||
seq_lens_cpu = dcp_local_seq_lens_cpu
|
||||
seq_lens = dcp_local_seq_lens
|
||||
|
||||
decode_metadata = self._build_decode(
|
||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||
seq_lens_cpu=seq_lens_cpu[:num_decodes],
|
||||
seq_lens_device=dcp_local_seq_lens[:num_decodes]
|
||||
if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
|
||||
else seq_lens[:num_decodes],
|
||||
seq_lens_device=seq_lens[:num_decodes],
|
||||
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
|
||||
query_start_loc_device=query_start_loc[: num_decodes + 1],
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
dcp_tot_seq_lens_device=seq_lens[:num_decodes]
|
||||
if self.dcp_world_size > 1
|
||||
else None,
|
||||
dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
|
||||
)
|
||||
|
||||
attn_metadata = self.metadata_cls(
|
||||
|
||||
@@ -173,7 +173,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
) -> FlashAttnMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
max_seq_len = seq_lens_device.max().item()
|
||||
max_seq_len = seq_lens_cpu.max().item()
|
||||
|
||||
# For Flash Attention MLA + full cudagraph
|
||||
max_num_splits = 0
|
||||
|
||||
@@ -92,6 +92,7 @@ class CommonAttentionMetadata:
|
||||
encoder_seq_lens: np.ndarray | None = None
|
||||
|
||||
dcp_local_seq_lens: torch.Tensor | None = None
|
||||
dcp_local_seq_lens_cpu: torch.Tensor | None = None
|
||||
"""Sequence lengths of the local rank in decode context parallelism world"""
|
||||
|
||||
|
||||
|
||||
@@ -1450,9 +1450,12 @@ class GPUModelRunner(
|
||||
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
|
||||
:num_reqs
|
||||
]
|
||||
dcp_local_seq_lens = (
|
||||
self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
|
||||
)
|
||||
|
||||
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
|
||||
if self.dcp_world_size > 1:
|
||||
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
|
||||
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
|
||||
|
||||
spec_decode_common_attn_metadata = None
|
||||
|
||||
if for_cudagraph_capture:
|
||||
@@ -1520,6 +1523,7 @@ class GPUModelRunner(
|
||||
causal=True,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
|
||||
)
|
||||
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
|
||||
@@ -204,14 +204,14 @@ class Worker(WorkerBase):
|
||||
assert self.local_rank < torch.cuda.device_count(), (
|
||||
f"DP adjusted local rank {self.local_rank} is out of bounds. "
|
||||
)
|
||||
visible_device_count = (
|
||||
torch.cuda.device_count() if torch.cuda.is_available() else 0
|
||||
)
|
||||
assert self.parallel_config.local_world_size <= visible_device_count, (
|
||||
f"local_world_size ({self.parallel_config.local_world_size}) must be "
|
||||
f"less than or equal to the number of visible devices "
|
||||
f"({visible_device_count})."
|
||||
)
|
||||
visible_device_count = (
|
||||
torch.cuda.device_count() if torch.cuda.is_available() else 0
|
||||
)
|
||||
assert self.parallel_config.local_world_size <= visible_device_count, (
|
||||
f"local_world_size ({self.parallel_config.local_world_size}) must "
|
||||
f"be less than or equal to the number of visible devices "
|
||||
f"({visible_device_count})."
|
||||
)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user