diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 740be2bc8..04d7cdc3d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -31,16 +31,6 @@ steps: ##### fast check tests ##### -- label: Documentation Build # 2min - mirror_hardwares: [amdexperimental] - working_dir: "/vllm-workspace/test_docs" - fast_check: true - no_gpu: True - commands: - - pip install -r ../requirements/docs.txt - # TODO: add `--strict` once warnings in docstrings are fixed - - mkdocs build - - label: Pytorch Nightly Dependency Override Check # 2min # if this test fails, it means the nightly torch version is not compatible with some # of the dependencies. Please check the error message and add the package to whitelist @@ -669,6 +659,7 @@ steps: - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py diff --git a/.gitignore b/.gitignore index 721dd7536..465935d48 100644 --- a/.gitignore +++ b/.gitignore @@ -207,3 +207,6 @@ shellcheck*/ # Ignore moe/marlin_moe gen code csrc/moe/marlin_moe_wna16/kernel_* + +# Ignore ep_kernels_workspace folder +ep_kernels_workspace/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 43b9d9b64..471fb2ad3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -249,7 +249,6 @@ set(VLLM_EXT_SRC "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" - "csrc/prepare_inputs/advance_step.cu" "csrc/custom_all_reduce.cu" "csrc/torch_bindings.cpp") @@ -352,6 +351,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) @@ -365,7 +368,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${MARLIN_SRCS}" CUDA_ARCHS "${MARLIN_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu" + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") + message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") else() message(STATUS "Not building Marlin kernels as no compatible archs found" @@ -855,6 +863,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${MOE_WNAA16_MARLIN_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index f62d8102e..904f80534 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,63 +1,199 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import asyncio +import logging import os import aiohttp -from quart import Quart, make_response, request +from quart import Quart, Response, make_response, request +from rate_limiter import RateLimiter +from request_queue import RequestQueue -AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) - -app = Quart(__name__) +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -async def forward_request(url, data): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: +def parse_args(): + """parse command line arguments""" + parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server") + + # Add args + parser.add_argument( + "--timeout", + type=float, + default=300, + help="Timeout for backend service requests in seconds (default: 300)", + ) + parser.add_argument( + "--max-concurrent", + type=int, + default=100, + help="Maximum concurrent requests to backend services (default: 100)", + ) + parser.add_argument( + "--queue-size", + type=int, + default=500, + help="Maximum number of requests in the queue (default: 500)", + ) + parser.add_argument( + "--rate-limit", + type=int, + default=40, + help="Maximum requests per second (default: 40)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + ) + parser.add_argument( + "--prefill-url", + type=str, + default="http://localhost:8100/v1/completions", + help="Prefill service endpoint URL", + ) + parser.add_argument( + "--decode-url", + type=str, + default="http://localhost:8200/v1/completions", + help="Decode service endpoint URL", + ) + + return parser.parse_args() + + +def main(): + """parse command line arguments""" + args = parse_args() + + # Initialize configuration using command line parameters + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout) + MAX_CONCURRENT_REQUESTS = args.max_concurrent + REQUEST_QUEUE_SIZE = args.queue_size + RATE_LIMIT = args.rate_limit + PREFILL_SERVICE_URL = args.prefill_url + DECODE_SERVICE_URL = args.decode_url + PORT = args.port + + app = Quart(__name__) + + # Initialize the rate limiter and request queue + rate_limiter = RateLimiter(RATE_LIMIT) + request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE) + + # Attach the configuration object to the application instance + app.config.update( + { + "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT, + "rate_limiter": rate_limiter, + "request_queue": request_queue, + "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL, + "DECODE_SERVICE_URL": DECODE_SERVICE_URL, + } + ) + + # Start queue processing on app startup + @app.before_serving + async def startup(): + """Start request processing task when app starts serving""" + asyncio.create_task(request_queue.process()) + + async def forward_request(url, data): + """Forward request to backend service with rate limiting and error handling""" headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} - async with session.post(url=url, json=data, headers=headers) as response: - if response.status == 200: - # if response.headers.get('Transfer-Encoding') == 'chunked': - if True: - async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes - else: - content = await response.read() - yield content - -@app.route("/v1/completions", methods=["POST"]) -async def handle_request(): - try: - original_request_data = await request.get_json() - - prefill_request = original_request_data.copy() - # change max_tokens = 1 to let it only do prefill - prefill_request["max_tokens"] = 1 - - # finish prefill - async for _ in forward_request( - "http://localhost:8100/v1/completions", prefill_request + # Use rate limiter as context manager + async with ( + rate_limiter, + aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, ): - continue + try: + async with session.post( + url=url, json=data, headers=headers + ) as response: + if response.status == 200: + # Stream response chunks + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + # Handle backend service errors + error_text = await response.text() + logger.error( + "Backend service error: %s - %s", + response.status, + error_text, + ) + yield b'{"error": "Backend service error"}' + except aiohttp.ClientError as e: + # Handle connection errors + logger.error("Connection error to %s: %s", url, str(e)) + yield b'{"error": "Service unavailable"}' + except asyncio.TimeoutError: + # Handle timeout errors + logger.error("Timeout connecting to %s", url) + yield b'{"error": "Service timeout"}' - # return decode - generator = forward_request( - "http://localhost:8200/v1/completions", original_request_data - ) - response = await make_response(generator) - response.timeout = None + async def process_request(): + """Process a single request through prefill and decode stages""" + try: + original_request_data = await request.get_json() - return response + # Create prefill request (max_tokens=1) + prefill_request = original_request_data.copy() + prefill_request["max_tokens"] = 1 - except Exception as e: - import sys - import traceback + # Execute prefill stage + async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request): + continue - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server") - print(e) - print("".join(traceback.format_exception(*exc_info))) + # Execute decode stage and stream response + generator = forward_request(DECODE_SERVICE_URL, original_request_data) + response = await make_response(generator) + response.timeout = None # Disable timeout for streaming response + return response + + except Exception: + logger.exception("Error processing request") + return Response( + response=b'{"error": "Internal server error"}', + status=500, + content_type="application/json", + ) + + @app.route("/v1/completions", methods=["POST"]) + async def handle_request(): + """Handle incoming API requests with concurrency and rate limiting""" + # Create task for request processing + task = asyncio.create_task(process_request()) + + # Enqueue request or reject if queue is full + if not await request_queue.enqueue(task): + return Response( + response=b'{"error": "Server busy, try again later"}', + status=503, + content_type="application/json", + ) + + try: + # Return the response from the processing task + return await task + except asyncio.CancelledError: + # Handle task cancellation (timeout or queue full) + logger.warning("Request cancelled due to timeout or queue full") + return Response( + response=b'{"error": "Request cancelled"}', + status=503, + content_type="application/json", + ) + + # Start the Quart server with host can be set to 0.0.0.0 + app.run(port=PORT) if __name__ == "__main__": - app.run(port=8000) + main() diff --git a/benchmarks/disagg_benchmarks/rate_limiter.py b/benchmarks/disagg_benchmarks/rate_limiter.py new file mode 100644 index 000000000..87ac8cb6a --- /dev/null +++ b/benchmarks/disagg_benchmarks/rate_limiter.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time + + +class RateLimiter: + """Token bucket rate limiter implementation""" + + def __init__(self, rate_limit): + self.rate_limit = rate_limit # Requests per second + self.num_available_tokens = rate_limit # Available tokens + self.last_refill = time.monotonic() # Last token refill time + self.lock = asyncio.Lock() # Synchronization lock + + async def acquire(self): + """Acquire a token from the rate limiter""" + while True: + async with self.lock: + current_time = time.monotonic() + elapsed = current_time - self.last_refill + + # Refill num_available_tokens if more than 1 second has passed + if elapsed > 1.0: + self.num_available_tokens = self.rate_limit + self.last_refill = current_time + + # Check if num_available_tokens are available + if self.num_available_tokens > 0: + self.num_available_tokens -= 1 + return True + + # Calculate wait time if no num_available_tokens available + wait_time = 1.0 - elapsed + await asyncio.sleep(wait_time) + + async def __aenter__(self): + """Enter async context manager - acquire token""" + await self.acquire() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + """Exit async context manager - no cleanup needed""" + pass diff --git a/benchmarks/disagg_benchmarks/request_queue.py b/benchmarks/disagg_benchmarks/request_queue.py new file mode 100644 index 000000000..410bcb956 --- /dev/null +++ b/benchmarks/disagg_benchmarks/request_queue.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from collections import deque + + +class RequestQueue: + """Request queue manager with concurrency control""" + + def __init__(self, max_concurrent, max_queue_size): + # Maximum concurrent requests + self.max_concurrent = max_concurrent + self.max_queue_size = max_queue_size # Maximum queue size + # Concurrency control + self.semaphore = asyncio.Semaphore(max_concurrent) + self.queue = deque() # Request queue + self.queue_size = 0 # Current queue size + self.lock = asyncio.Lock() # Sync queue Lock + + async def enqueue(self, task): + """Add a request task to the queue""" + async with self.lock: + if self.queue_size >= self.max_queue_size: + return False + + self.queue.append(task) + self.queue_size += 1 + return True + + async def process(self): + """Process queued requests using semaphore for concurrency control""" + while True: + if self.queue: + async with self.semaphore, self.lock: + task = self.queue.popleft() + self.queue_size -= 1 + await task + await asyncio.sleep(0.01) # Yield control to event loop diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index f73d0511e..975d10f2e 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: a=bt.a, c=None, b_q_weight=w_q, + b_bias=None, b_scales=w_s, global_scale=None, b_zeros=w_zp, diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index d0f85e236..68a8750f5 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = + ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 49f33718a..698deb107 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME { TEMPLATE = ("template __global__ void Marlin<" "{{scalar_t}}, " "{{w_type_id}}, " + "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " @@ -77,6 +78,7 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 + # mxfp4 only supports group_size == 32 if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: continue # other quantization methods don't support group_size = 16 @@ -89,9 +91,22 @@ def generate_new_kernels(): c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: + s_type = "vllm::kFE4M3fn" + elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: + s_type = "vllm::kFE8M0fnu" + if dtype == "fp16": + # we cannot safely dequantize e8m0 to fp16, so skip this + continue + elif dtype == "fp16": + s_type = "vllm::kFloat16" + elif dtype == "bf16": + s_type = "vllm::kBFloat16" + template_str = jinja2.Template(TEMPLATE).render( scalar_t=c_dtype, w_type_id=scalar_type + ".id()", + s_type_id=s_type + ".id()", threads=threads, thread_m_blocks=max(m_blocks, 1), thread_n_blocks=n_blocks, diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 537282aba..6190f7ee2 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,23 +7,25 @@ #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ - const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ b_bias_ptr, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template ::value) { + static_assert(s_type == vllm::kBFloat16); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kFloat16); + } + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || w_type == vllm::kU4B8 || w_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - !is_int_type || + w_type == vllm::kFE4M3fn || + w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == vllm::kU8); @@ -365,6 +379,7 @@ __global__ void Marlin( const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); + const int b_bias_expert_stride = prob_n / 8; // parallel: num valid moe blocks int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; @@ -475,7 +490,7 @@ __global__ void Marlin( for (int i = 0; i < 4; i++) { int idx = tid4 * 4 + i; idx = idx < block_num_valid_tokens ? idx : 0; - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { sh_block_topk_weights[idx] = __hmul2( global_scale, Dtype::num2num2(Dtype::float2num( topk_weights_ptr[sh_block_sorted_ids[idx]]))); @@ -513,7 +528,7 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { uint16_t val = scale2_ptr[expert_id]; global_scale = Dtype::num2num2(*reinterpret_cast(&val)); } @@ -526,6 +541,9 @@ __global__ void Marlin( if constexpr (has_act_order) { g_idx += (expert_id - old_expert_id) * prob_k; } + if (has_bias) { + b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride; + } read_moe_block_data(block_id); }; @@ -721,7 +739,7 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; + s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + @@ -734,6 +752,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + int bias_sh_rd; + if constexpr (m_block_size_8) { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + } else { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + } + + int bias_sh_wr = threadIdx.x; + int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + // Zero-points have the same read layout as the scales // (without column-wise case) constexpr int num_col_threads = 8; @@ -793,7 +823,19 @@ __global__ void Marlin( constexpr int sh_b_size = stages * b_sh_stage; int4* sh_b = sh_new; int4* sh_red = sh_new; - int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + constexpr int sh_size_b_red_min = + (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_size_b_red_max = + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_bias_size = (thread_n_blocks * 16 / 8); + constexpr int sh_b_red_bias_size = + sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) + ? sh_size_b_red_max + : (sh_size_b_red_min + sh_bias_size); + + int4* sh_bias = sh_new + sh_size_b_red_min; + int4* sh_g_idx = sh_new + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -803,9 +845,9 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - constexpr int shm_size_used = - moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + - (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int shm_size_used = moe_block_size + + stages * (g_idx_stage + zp_sh_stage) + + sh_s_size + sh_b_red_bias_size; // all remaining shared memory is used to cache A (input) // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` @@ -816,7 +858,8 @@ __global__ void Marlin( FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order + FragS frag_s[2][4]; // No act-order + FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order int frag_qzp[2][num_ints_per_thread]; // Zero-points FragZP frag_zp; // Zero-points in fp16 @@ -1065,10 +1108,15 @@ __global__ void Marlin( if constexpr (w_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else { + } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + + k % 2]; } } } @@ -1281,9 +1329,9 @@ __global__ void Marlin( int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, - reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } @@ -1566,7 +1614,7 @@ __global__ void Marlin( // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. - auto write_result = [&]() { + auto write_result = [&](bool last) { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -1592,7 +1640,7 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); @@ -1601,14 +1649,27 @@ __global__ void Marlin( if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - res = __hmul2(res, s[0]); + scalar_t2 tmp_scale = s[0]; + if constexpr (m_block_size_8) { + tmp_scale = Dtype::num2num2( + reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); + } + res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if (!mul_topk_weights) { res = __hmul2(res, global_scale); } } + if (has_bias && last) { + scalar_t2 tmp_bias = b_bias[0]; + if constexpr (m_block_size_8) { + tmp_bias = Dtype::num2num2( + reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); + } + res = __hadd2(res, tmp_bias); + } if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; @@ -1626,19 +1687,25 @@ __global__ void Marlin( if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], - frag_s[j / 2][2 * (j % 2) + 0]); + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], - frag_s[j / 2][2 * (j % 2) + 1]); + frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } else { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } } c_sh_wr += 16 * (4 * c_sh_stride); @@ -1805,6 +1872,14 @@ __global__ void Marlin( } thread_block_reduce(); + + if (has_bias && last) { + __syncthreads(); + cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd], + threadIdx.x < 16 * thread_n_blocks / 8); + cp_async_fence(); + } + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { @@ -1867,11 +1942,20 @@ __global__ void Marlin( } barrier_release(&locks[locks_off], last); } + + if (has_bias && last) { + cp_async_wait<0>(); + __syncthreads(); + reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + __syncthreads(); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); if (last || use_atomic_add) // only the last block in a slice actually writes the result - write_result(); + write_result(last); int old_slice_row = slice_row; slice_row = 0; slice_col_par++; @@ -1904,6 +1988,7 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } + bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading if constexpr (has_act_order) { slice_k_start = tb_k * slice_row; diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 2cff04f69..601e2aa6f 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -51,8 +51,9 @@ __global__ void permute_cols_kernel( } // namespace marlin torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional const& c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; - int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16); + int tb_m = thread_m_blocks * 16; // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) @@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; + int sh_bias_size = tb_n * 2; + int tmp_size = + (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size; + tmp_size = max(max(sh_b_size, sh_red_size), tmp_size); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); @@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, sh_zp_size = sh_s_size / 2; } - int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + - sh_zp_size + sh_g_idx_size + sh_block_meta_size; + int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size + + sh_g_idx_size + sh_block_meta_size; return total_size; } @@ -270,20 +276,25 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, int cache_size = get_kernel_cache_size( th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size <= max_shared_mem; + return cache_size + 512 <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - kernel = Marlin; \ + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + constexpr auto S_TYPE = \ + W_TYPE == vllm::kFE2M1f \ + ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ + : (std::is_same::value ? vllm::kFloat16 \ + : vllm::kBFloat16); \ + kernel = Marlin; \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -335,31 +346,45 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define FP4_GET_IF(W_TYPE) \ - FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FP4_GET_IF_M234(W_TYPE, 8, 4, 128) - #define BIGGROUP_GET_IF(W_TYPE) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) + #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define NVFP4_GET_IF(W_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) + + #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF(W_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) + // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ @@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) - BIGGROUP_GET_IF(vllm::kFE4M3fn) + NVFP4_GET_IF(vllm::kFE2M1f) - FP4_GET_IF(vllm::kFE2M1f) + BIGGROUP_GET_IF(vllm::kFE4M3fn) ACT_GET_IF(vllm::kU4B8) ACT_GET_IF(vllm::kU8B128) + if (std::is_same::value) { + if (false) { + } + MXFP4_GET_IF(vllm::kFE2M1f) + } return kernel; } @@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, } template -void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, - void* sorted_token_ids, void* expert_ids, +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, + void* s, void* s2, void* zp, void* g_idx, void* perm, + void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, bool has_zp, int num_groups, int group_size, - int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, bool use_atomic_add, bool use_fp32_reduce, + vllm::ScalarType const& q_type, bool has_bias, + bool has_act_order, bool is_k_full, bool has_zp, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { int thread_m_blocks = div_ceil(moe_block_size, 16); bool m_block_size_8 = moe_block_size == 8; @@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; const int4* s_ptr = (const int4*)s; const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; @@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, - prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); + prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); // clang-format on } @@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm( num_groups = b_scales.size(1); torch::Tensor g_idx, perm, a_tmp; - ; if (g_idx_or_none.has_value() && perm_or_none.has_value()) { g_idx = g_idx_or_none.value(); perm = perm_or_none.value(); @@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f, - "global_scale can only be used for float4_e2m1f."); + TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), - "the global_scale parameter must be passed for float4_e2m1f."); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + "the global_scale parameter must be passed for nvfp4 format."); + } + + bool has_bias = b_bias_or_none.has_value(); + torch::Tensor b_bias; + if (has_bias) { + b_bias = b_bias_or_none.value(); + TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU"); + TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous"); + TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n"); + TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1"); + } else { + b_bias = torch::empty({0}, options); } torch::Tensor b_zeros; @@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm( b_zeros = torch::empty({0}, options); } bool has_zp = b_zeros.size(-1) > 0; - if (has_zp) { TORCH_CHECK( b_q_type == vllm::kU4 || b_q_type == vllm::kU8, @@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm( if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), - topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, - size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, - is_k_full, has_zp, num_groups, group_size, dev, + c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), + moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, + workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); } else { TORCH_CHECK(false, "moe_wna16_marlin_gemm only supports bfloat16 and float16"); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d96e082f6..7e49f68f6 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," - "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? " + "Tensor! b_q_weight, Tensor? b_bias_or_none," + "Tensor! b_scales, Tensor? global_scale, Tensor? " "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," diff --git a/csrc/ops.h b/csrc/ops.h index 207291ece..3e29f0a97 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -145,22 +145,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); -void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, - int64_t block_size, torch::Tensor& input_tokens, - torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, - torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, - torch::Tensor& block_tables); - -void advance_step_flashinfer( - int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables, - torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, - torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); - void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu deleted file mode 100644 index 3d5077d9d..000000000 --- a/csrc/prepare_inputs/advance_step.cu +++ /dev/null @@ -1,336 +0,0 @@ -/* - * The goal of this GPU kernel is to advance input tensors on the GPU directly - * PR: https://github.com/vllm-project/vllm/pull/6338 - * Current restrictions: - * 1. Specialized for DraftModelRunner - * 2. Supports flash_attn only - */ - -#include "advance_step.cuh" - -namespace prepare_inputs { - -// -template -__global__ void advance_step_flashattn_kernel( - int num_seqs, int num_queries, int block_size, long* input_tokens_ptr, - long const* sampled_token_ids_ptr, long* input_positions_ptr, - int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, - int64_t const block_tables_stride) { - int const n_pad = num_seqs - num_queries; - if (n_pad && blockIdx.x == 0) { - // Handle cuda graph padding - int const offset = num_queries; - for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { - input_tokens_ptr[offset + i] = 0; - input_positions_ptr[offset + i] = 0; - slot_mapping_ptr[offset + i] = -1; - } - } - - int num_query_blocks = div_ceil(num_queries, num_threads); - - if (blockIdx.x >= num_query_blocks) { - return; - } - - int cur_query_id = blockIdx.x * num_threads + threadIdx.x; - - if (cur_query_id >= num_queries) { - return; - } - - // Update input_tokens - input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; - - int seq_len = seq_lens_ptr[cur_query_id]; - int next_seq_len = seq_len + 1; - int next_input_pos = next_seq_len - 1; - - // Update seq_lens - seq_lens_ptr[cur_query_id] = next_seq_len; - // Update input_positions - input_positions_ptr[cur_query_id] = next_input_pos; - - int const* seq_block_tables_ptr = - block_tables_ptr + block_tables_stride * cur_query_id; - - int block_index = next_input_pos / block_size; - int block_offset = next_input_pos % block_size; - - int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; - // Update slot_mapping - slot_mapping_ptr[cur_query_id] = slot_num; -} - -inline void verify_tensor(std::string const& name, torch::Tensor const& t, - int64_t const size_0, int64_t const size_1, - c10::ScalarType const type) { - bool size_0_cond = true; - if (size_0 != -1) { - size_0_cond = t.size(0) == size_0; - } - - bool size_1_cond = true; - if (size_1 != -1) { - size_1_cond = t.size(1) == size_1; - } - - bool is_contiguous = t.is_contiguous(); - bool same_type = t.dtype() == type; - - bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; - if (!pass) { - TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), - " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), - " is not as expected: shape = [", size_0, ", ", size_1, - "], type = ", type); - } -} - -/// each thread processes a block per query -__global__ void advance_step_flashinfer_kernel( - int num_threads, int num_seqs, int num_queries, int block_size, - long* input_tokens_ptr, long const* sampled_token_ids_ptr, - long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, - int const* block_tables_ptr, int64_t const block_tables_stride, - int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { - int const n_pad = num_seqs - num_queries; - if (n_pad && blockIdx.x == 0) { - // Handle cuda graph padding - int const offset = num_queries; - for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { - input_tokens_ptr[offset + i] = 0; - input_positions_ptr[offset + i] = 0; - slot_mapping_ptr[offset + i] = -1; - } - } - int num_query_blocks = div_ceil(num_queries, num_threads); - - if (blockIdx.x < num_query_blocks) { - int cur_query_id = blockIdx.x * num_threads + threadIdx.x; - - if (cur_query_id < num_queries) { - // Update input_tokens - input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; - - int seq_len = seq_lens_ptr[cur_query_id]; - int next_seq_len = seq_len + 1; - int next_input_pos = next_seq_len - 1; - - // Update seq_lens - seq_lens_ptr[cur_query_id] = next_seq_len; - // Update input_positions - input_positions_ptr[cur_query_id] = next_input_pos; - - int const* seq_block_tables_ptr = - block_tables_ptr + block_tables_stride * cur_query_id; - - int block_index = next_input_pos / block_size; - int block_offset = next_input_pos % block_size; - - // Update paged_kv_last_page_len - paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1; - - int slot_num = - seq_block_tables_ptr[block_index] * block_size + block_offset; - // Update slot_mapping - slot_mapping_ptr[cur_query_id] = slot_num; - block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size); - } - } -} - -__global__ void advance_step_flashinfer_indptr_kernel( - int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, - int* block_table_bound_ptr) { - int idx = blockIdx.x * num_threads + threadIdx.x; - // Update paged_kv_indptr - if (idx == 0) { - paged_kv_indptr_ptr[idx] = 0; - } - if (idx < num_queries) { - int sum = 0; - for (int i = 0; i <= idx; ++i) { - sum += block_table_bound_ptr[i]; - } - paged_kv_indptr_ptr[idx + 1] = sum; - } -} - -__global__ void advance_step_flashinfer_indices_kernel( - int num_seqs, int num_queries, int const* block_tables_ptr, - int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr, - int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { - // note: max_num_blocks_per_seq = block_tables.stride(0) - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - // when cuda graphs are enabled, paged_kv_indptr tensor - // has to be updated for the padded queries - // tid represents a query# for paged_kv_indptr tensor - if (num_queries < tid && tid <= num_seqs) { - paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries]; - } - - // each thread processes a block_ptr in block_tables - // block_tables shape: [num_queries, max_num_blocks_per_seq] - // paged_kv_indices is flattened block_tables. - for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq); - idx += (gridDim.x * blockDim.x)) { - // block_tables-row = paged_kv_indptr[queryNum] - int queryNum = idx / max_num_blocks_per_seq; - int col = idx % max_num_blocks_per_seq; - if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) { - int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col; - int block_tables_idx = queryNum * max_num_blocks_per_seq + col; - paged_kv_indices_ptr[indices_arr_idx] = - block_tables_ptr[block_tables_idx]; - } - } -} - -void advance_step_flashattn(int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables) { // type: int - - if (logging) { - printf("advance_step_flashattn:\n"); - printf(" num_seqs = %d\n", num_seqs); - printf(" num_queries = %d\n", num_queries); - printf(" block_size = %d\n", block_size); - } - // Verify all tensors - verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); - verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, - at::kLong); - verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); - verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); - verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); - verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); - - int dev = sampled_token_ids.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - - int blocks; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - - advance_step_flashattn_kernel - <<>>( - num_seqs, num_queries, block_size, - reinterpret_cast(input_tokens.data_ptr()), - reinterpret_cast(sampled_token_ids.data_ptr()), - reinterpret_cast(input_positions.data_ptr()), - reinterpret_cast(seq_lens.data_ptr()), - reinterpret_cast(slot_mapping.data_ptr()), - reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0)); -} - -void advance_step_flashinfer( - int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables, // type: int - torch::Tensor& paged_kv_indices, // type: int - torch::Tensor& paged_kv_indptr, // type: int - torch::Tensor& paged_kv_last_page_len, // type: int - torch::Tensor& block_table_bound) { // type: int - - if (logging) { - printf("advance_step_flashinfer:\n"); - printf(" num_seqs = %d\n", num_seqs); - printf(" num_queries = %d\n", num_queries); - printf(" block_size = %d\n", block_size); - printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0)); - } - // Verify all tensors - verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); - // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, - // at::kLong); - verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); - verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); - verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); - verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); - - verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt); - verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt); - verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1, - at::kInt); - - verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt); - - int dev = sampled_token_ids.get_device(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - - int blocks; - int threads; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); - - TORCH_CHECK((blocks * threads > num_queries), - "multi-step: not enough threads to map to num_queries = ", - num_queries, " block_tables.stride(0) = ", block_tables.stride(0), - " blocks = ", blocks, " max_threads = ", threads); - if (logging) { - printf("launching kernels with %d blocks and %d threads\n", blocks, - threads); - } - advance_step_flashinfer_kernel<<>>( - threads, num_seqs, num_queries, block_size, - reinterpret_cast(input_tokens.data_ptr()), - reinterpret_cast(sampled_token_ids.data_ptr()), - reinterpret_cast(input_positions.data_ptr()), - reinterpret_cast(seq_lens.data_ptr()), - reinterpret_cast(slot_mapping.data_ptr()), - reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0), - reinterpret_cast(paged_kv_last_page_len.data_ptr()), - reinterpret_cast(block_table_bound.data_ptr())); - - advance_step_flashinfer_indptr_kernel<<>>( - threads, num_seqs, num_queries, - reinterpret_cast(paged_kv_indptr.data_ptr()), - reinterpret_cast(block_table_bound.data_ptr())); - - advance_step_flashinfer_indices_kernel<<>>( - num_seqs, num_queries, - reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0), - reinterpret_cast(paged_kv_indices.data_ptr()), - reinterpret_cast(paged_kv_indptr.data_ptr()), - reinterpret_cast(block_table_bound.data_ptr())); -} - -} // namespace prepare_inputs - -void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, - int64_t block_size, torch::Tensor& input_tokens, - torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, - torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, - torch::Tensor& block_tables) { - prepare_inputs::advance_step_flashattn( - num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, block_tables); -} - -void advance_step_flashinfer( - int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables, - torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, - torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) { - prepare_inputs::advance_step_flashinfer( - num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, - paged_kv_indptr, paged_kv_last_page_len, block_table_bound); -} diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh deleted file mode 100644 index f21574681..000000000 --- a/csrc/prepare_inputs/advance_step.cuh +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -namespace prepare_inputs { - -static constexpr int max_threads = 256; -static constexpr bool logging = false; - -constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } - -} // namespace prepare_inputs diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index ae0d6c0f2..e8b0c302b 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -470,11 +470,12 @@ __device__ inline void dequant( frag_b[0] = __hmul2(frag_b[0], bias_reg); } -template +template __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); template <> -__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { +__device__ inline void dequant_fp8_scales( + int q, half2* frag_b) { int Out1 = (q & 0xFF00FF00) >> 1; ; q <<= 8; @@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales(int q, half2* frag_b) { }; template <> -__device__ inline void dequant_fp8_scales(int q, - nv_bfloat162* frag_b) { +__device__ inline void dequant_fp8_scales( + int q, nv_bfloat162* frag_b) { constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; constexpr int MASK = 0x7F007F00; @@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales(int q, frag_b[0] = *reinterpret_cast(&Out2); } +template <> +__device__ inline void dequant_fp8_scales( + int q, nv_bfloat162* frag_b) { + // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, + // but we assume that such a extreme value would not occur in real models. + int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + #endif } // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 18fb6c1a8..7576e0548 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME { TEMPLATE = ("template __global__ void Marlin<" "{{scalar_t}}, " "{{w_type_id}}, " + "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " @@ -78,7 +79,8 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 - if scalar_type == "vllm::kFE2M1f" and group_blocks != 1: + # mxfp4 only supports group_size == 32 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: continue # other quantization methods don't support group_size = 16 if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: @@ -97,10 +99,23 @@ def generate_new_kernels(): # 4bit quantization and fp16 is_zp_float_list.append(True) + if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: + s_type = "vllm::kFE4M3fn" + elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: + s_type = "vllm::kFE8M0fnu" + if dtype == "fp16": + # we cannot safely dequantize e8m0 to fp16, so skip this + continue + elif dtype == "fp16": + s_type = "vllm::kFloat16" + elif dtype == "bf16": + s_type = "vllm::kBFloat16" + for is_zp_float in is_zp_float_list: template_str = jinja2.Template(TEMPLATE).render( scalar_t=c_dtype, w_type_id=scalar_type + ".id()", + s_type_id=s_type + ".id()", threads=threads, thread_m_blocks=max(m_blocks, 1), thread_n_blocks=n_blocks, diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 4a242f205..cc30abcf0 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int tb_m = thread_m_blocks * 16; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; - int sh_red_size = tb_m * (tb_n + 8); + int sh_red_size = tb_m * (tb_n + 8) * 2; + int sh_bias_size = tb_n * 2; + int tmp_size = + (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size; + tmp_size = max(max(sh_b_size, sh_red_size), tmp_size); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); @@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, sh_zp_size = sh_s_size / 2; } - int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + - sh_zp_size + sh_g_idx_size; + int total_size = + tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; return total_size; } @@ -237,20 +243,25 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size <= max_shared_mem; + return cache_size + 512 <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - kernel = Marlin; \ + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + constexpr auto S_TYPE = \ + W_TYPE == vllm::kFE2M1f \ + ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ + : (std::is_same::value ? vllm::kFloat16 \ + : vllm::kBFloat16); \ + kernel = Marlin; \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - #define FP4_GET_IF(W_TYPE) \ - FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + #define NVFP4_GET_IF(W_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128) + + #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF(W_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128) // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ @@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) - FP4_GET_IF(vllm::kFE2M1f) + NVFP4_GET_IF(vllm::kFE2M1f) BIGGROUP_GET_IF(vllm::kFE4M3fn) @@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, } FZP_GET_IF(vllm::kU4) } + if (std::is_same::value) { + if (false) { + } + MXFP4_GET_IF(vllm::kFE2M1f) + } return kernel; } @@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, } template -void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, - int prob_m, int prob_n, int prob_k, int lda, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, bool has_zp, int num_groups, int group_size, - int dev, cudaStream_t stream, int thread_k_init, +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, + void* s, void* s2, void* zp, void* g_idx, void* perm, + void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, + void* workspace, vllm::ScalarType const& q_type, bool has_bias, + bool has_act_order, bool is_k_full, bool has_zp, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k_init, int thread_n_init, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { if (has_zp) { @@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; const int4* s_ptr = (const int4*)s; const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; @@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, - prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, + g_idx_ptr, num_groups, + prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on @@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f, - "global_scale can only be used for float4_e2m1f."); + TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), - "the global_scale parameter must be passed for float4_e2m1f."); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + "the global_scale parameter must be passed for nvfp4 format."); + } + + bool has_bias = b_bias_or_none.has_value(); + torch::Tensor b_bias; + if (has_bias) { + b_bias = b_bias_or_none.value(); + TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU"); + TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous"); + TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n"); + TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1"); + } else { + b_bias = torch::empty({0}, options); } torch::Tensor b_zeros; @@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm( if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), - workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, + a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, + is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, - has_act_order, is_k_full, has_zp, num_groups, group_size, dev, + has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else { diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index f92056589..bb454f6af 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -10,15 +10,18 @@ #define MARLIN_KERNEL_PARAMS \ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ scales_ptr, \ const uint16_t *__restrict__ scale2_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ - bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ + int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template ::FragZP; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); + if constexpr (w_type == vllm::kFE2M1f) { + static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || + s_type == vllm::kFE8M0fnu && group_blocks == 2); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kBFloat16); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kFloat16); + } + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || w_type == vllm::kU4B8 || w_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - !is_int_type || + w_type == vllm::kFE4M3fn || + w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == vllm::kU8); scalar_t2 global_scale; - - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + // NVFP4 format requires global scale uint16_t val = scale2_ptr[0]; global_scale = Dtype::num2num2(*reinterpret_cast(&val)); } @@ -589,7 +604,7 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; + s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + @@ -602,6 +617,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + int bias_sh_rd; + if constexpr (m_block_size_8) { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + } else { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + } + + int bias_sh_wr = threadIdx.x; + int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + // Zero-points have the same read layout as the scales // (without column-wise case) constexpr int num_col_threads = 8; @@ -670,7 +697,19 @@ __global__ void Marlin( constexpr int sh_b_size = stages * b_sh_stage; int4* sh_b = sh; int4* sh_red = sh; - int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + constexpr int sh_size_b_red_min = + (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_size_b_red_max = + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_bias_size = (thread_n_blocks * 16 / 8); + constexpr int sh_b_red_bias_size = + sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) + ? sh_size_b_red_max + : (sh_size_b_red_min + sh_bias_size); + + int4* sh_bias = sh + sh_size_b_red_min; + int4* sh_g_idx = sh + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -680,15 +719,13 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - // constexpr int shm_size_used = - // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + - // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order + FragS frag_s[2][4]; // No act-order + FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order int frag_qzp[2][num_ints_per_thread]; // Zero-points FragZP frag_zp; // Zero-points in fp16 @@ -923,10 +960,15 @@ __global__ void Marlin( if constexpr (w_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else { + } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + + k % 2]; } } } @@ -1139,9 +1181,9 @@ __global__ void Marlin( int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, - reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } @@ -1411,7 +1453,7 @@ __global__ void Marlin( // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. - auto write_result = [&]() { + auto write_result = [&](bool last) { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -1438,7 +1480,7 @@ __global__ void Marlin( int c_gl_wr_end = c_gl_stride * prob_m; // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); @@ -1447,12 +1489,25 @@ __global__ void Marlin( if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - res = __hmul2(res, s[0]); + scalar_t2 tmp_scale = s[0]; + if constexpr (m_block_size_8) { + tmp_scale = Dtype::num2num2( + reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); + } + res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { res = __hmul2(res, global_scale); } + if (has_bias && last) { + scalar_t2 tmp_bias = b_bias[0]; + if constexpr (m_block_size_8) { + tmp_bias = Dtype::num2num2( + reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); + } + res = __hadd2(res, tmp_bias); + } if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; @@ -1470,19 +1525,25 @@ __global__ void Marlin( if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], - frag_s[j / 2][2 * (j % 2) + 0]); + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], - frag_s[j / 2][2 * (j % 2) + 1]); + frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } else { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } } c_sh_wr += 16 * (4 * c_sh_stride); @@ -1622,6 +1683,14 @@ __global__ void Marlin( } thread_block_reduce(); + + if (has_bias && last) { + __syncthreads(); + cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd], + threadIdx.x < 16 * thread_n_blocks / 8); + cp_async_fence(); + } + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { @@ -1684,11 +1753,20 @@ __global__ void Marlin( } barrier_release(&locks[locks_off], last); } + + if (has_bias && last) { + cp_async_wait<0>(); + __syncthreads(); + reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + __syncthreads(); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); if (last || use_atomic_add) // only the last block in a slice actually writes the result - write_result(); + write_result(last); slice_row = 0; slice_col_par++; slice_col++; @@ -1706,6 +1784,7 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } + bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading if constexpr (has_act_order) { slice_k_start = tb_k * slice_row; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 85b6abef0..a547baec5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -142,25 +142,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); - // prepare_inputs advance_step - ops.def( - "advance_step_flashattn(int num_seqs, int num_queries, int block_size, " - "Tensor! input_tokens, Tensor sampled_token_ids, " - "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, " - "Tensor block_tables) -> ()"); - ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn); - - ops.def( - "advance_step_flashinfer(" - " int num_seqs, int num_queries, int block_size," - " Tensor! input_tokens, Tensor sampled_token_ids," - " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping," - " Tensor block_tables, Tensor! paged_kv_indices," - " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len," - " Tensor! block_table_bounds" - ") -> ()"); - ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); - // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -326,6 +307,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " + "Tensor? b_bias_or_none," "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " diff --git a/docker/Dockerfile b/docker/Dockerfile index a20a4bfb2..66a6e6fd6 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -497,14 +497,11 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1 # Copy in the v1 package for testing (it isn't distributed yet) COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 -# doc requires source code -# we hide them inside `test_docs/` , so that this source code +# Source code is used in the `python_only_compile.sh` test +# We hide it inside `src/` so that this source code # will not be imported by other tests -RUN mkdir test_docs -RUN mv docs test_docs/ -RUN cp -r examples test_docs/ -RUN mv vllm test_docs/ -RUN mv mkdocs.yaml test_docs/ +RUN mkdir src +RUN mv vllm src/vllm #################### TEST IMAGE #################### #################### OPENAI API SERVER #################### diff --git a/docs/models/extensions/fastsafetensor.md b/docs/models/extensions/fastsafetensor.md index 531d58690..2a5a18102 100644 --- a/docs/models/extensions/fastsafetensor.md +++ b/docs/models/extensions/fastsafetensor.md @@ -2,4 +2,5 @@ Loading Model weights with fastsafetensors =================================================================== Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details. -For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true`` + +To enable this feature, use the ``--load-format fastsafetensors`` command-line argument diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index 9715ad66d..b92c6cef4 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -35,6 +35,7 @@ You can check if this is happening by trying the old defaults with `--generation If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue: - `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging. +- `export VLLM_LOG_STATS_INTERVAL=1.` to get log statistics more frequently for tracking running queue, waiting queue and cache hit states. - `export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem. - `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL. - `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time. diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index efe9c843f..97140a9db 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -3,7 +3,8 @@ import contextlib import os import weakref -from contextlib import ExitStack +from dataclasses import dataclass +from typing import Optional import pytest @@ -32,27 +33,130 @@ def temporary_environ(env_vars): os.environ[k] = v +@dataclass +class BackendConfig: + name: str + env_vars: dict + comp_config: dict + specific_gpu_arch: Optional[tuple] = None + + +# Define all backend configurations of full cudagraph to be tested +backend_configs = { + # FA3 on Hopper + "FA3": + BackendConfig(name="FA3", + env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, + comp_config={ + "cudagraph_mode": "FULL", + }, + specific_gpu_arch=(9, 0)), + # FlashMLA on Hopper + "FlashMLA": + BackendConfig(name="FlashMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASHMLA", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(9, 0)), + # Cutlass MLA on Blackwell + "CutlassMLA": + BackendConfig( + name="CutlassMLA", + env_vars={ + "VLLM_USE_V1": "1", + "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", + "FORCE_NUM_KV_SPLITS": + "1", # TODO: remove this when hang issue is fixed + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512], + }, + specific_gpu_arch=(10, 0)), + # FA2 + "FA2": + BackendConfig(name="FA2", + env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, + comp_config={ + "cudagraph_mode": "FULL", + }), + # Triton Attention + "TritonAttn": + BackendConfig(name="TritonAttn", + env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, + comp_config={ + "cudagraph_mode": "FULL", + }), + # FlashInfer + "FlashInfer": + BackendConfig(name="FlashInfer", + env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), +} + +test_params_full_cudagraph = [] + +# deepseek-ai/DeepSeek-V2-Lite with MLA +MLA_backends = ["FlashMLA", "CutlassMLA"] +for mla_backend in MLA_backends: + test_params_full_cudagraph.append( + pytest.param( + ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))) + +# Qwen/Qwen2-1.5B-Instruct with other backends +other_backend_configs = [ + backend_configs[c] for c in backend_configs if c not in MLA_backends +] +for backend_config in other_backend_configs: + test_params_full_cudagraph.append( + pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))) + + @pytest.fixture(scope="class") def llm_pair(request): - model = request.param + model, backend_config = request.param - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": "3" - }): + # Dynamically skip test if GPU capability is not met + if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ + != current_platform.get_device_capability(): + if backend_config.specific_gpu_arch == (9, 0): + pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") + elif backend_config.specific_gpu_arch == (10, 0): + pytest.skip("Only Blackwell GPUs support Cutlass MLA") + + env_vars = { + "VLLM_USE_V1": "1", + # Force native sampler to avoid potential nondeterminism in FlashInfer + # when per-request generators are not used in V1. + "VLLM_USE_FLASHINFER_SAMPLER": "0", + **backend_config.env_vars, + } + with temporary_environ(env_vars): full = LLM( model=model, - gpu_memory_utilization=0.45, + gpu_memory_utilization=0.43, trust_remote_code=True, max_model_len=1024, - compilation_config=CompilationConfig(full_cuda_graph=True), + max_num_seqs=128, + compilation_config=\ + CompilationConfig(**backend_config.comp_config), + generation_config="vllm", + seed=42, ) piecewise = LLM( model=model, - gpu_memory_utilization=0.45, + gpu_memory_utilization=0.43, trust_remote_code=True, max_model_len=1024, - compilation_config=CompilationConfig(), + max_num_seqs=128, + compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"), + generation_config="vllm", + seed=42, ) # PyTest caches the fixture values so we use weakref.proxy to enable GC @@ -66,16 +170,7 @@ def llm_pair(request): ) -@pytest.mark.parametrize( - "llm_pair", - [ - # Model names for the llm_pair fixture - "deepseek-ai/DeepSeek-V2-Lite", - "Qwen/Qwen2-1.5B-Instruct" - ], - indirect=True) -@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), - reason="Only Hopper GPUs support FA3 and FlashMLA") +@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True) class TestFullCUDAGraph: """ Use a class such that an llm pair is constructed once for all @@ -104,12 +199,14 @@ class TestFullCUDAGraph: full cudagraph compilation works for padded cases too. """ - piecewise_llm, full_cudagraph_llm = llm_pair + full_cudagraph_llm, piecewise_llm = llm_pair - prompts = ["Hello, my name is"] * batch_size + prompts = ["the quick brown fox"] * batch_size + # Use purely greedy decoding to avoid top-p truncation sensitivity + # that can amplify tiny numeric differences across runtimes. sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, - top_p=0.95) + top_p=1.0) piecewise_responses = piecewise_llm.generate(prompts, sampling_params) full_responses = full_cudagraph_llm.generate(prompts, sampling_params) @@ -117,42 +214,16 @@ class TestFullCUDAGraph: # Check that all responses are the same for piecewise_res, full_res in zip(piecewise_responses, full_responses): - assert piecewise_res.outputs[0].text == full_res.outputs[0].text - - -@pytest.mark.parametrize( - "model, supported", - [ - ("Qwen/Qwen2-1.5B-Instruct", True), - # MLA does not support capturing CUDA Graphs with size > max_num_seqs - ("deepseek-ai/DeepSeek-V2-Lite", False), - ]) -@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), - reason="Only Hopper GPUs support FA3 and FlashMLA") -def test_lower_max_num_seqs(model, supported): - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": "3" - }), ExitStack() as stack: - if not supported: - stack.enter_context(pytest.raises(RuntimeError)) - - llm = LLM(model=model, - max_num_seqs=256, - trust_remote_code=True, - max_model_len=1024, - compilation_config=CompilationConfig( - full_cuda_graph=True, - cudagraph_capture_sizes=[64, 256, 512])) - llm.generate(["Hello, my name is"] * 10) + assert piecewise_res.outputs[0].text.lower() == \ + full_res.outputs[0].text.lower() @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") def test_full_cudagraph_with_invalid_backend(): with temporary_environ({ "VLLM_USE_V1": "1", - "VLLM_FLASH_ATTN_VERSION": - "2" #FA2 not supported with full_cuda_graph + "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION" + # Flex_Attention is not supported with full cuda graph }), pytest.raises(RuntimeError): LLM(model="Qwen/Qwen2-1.5B-Instruct", - compilation_config=CompilationConfig(full_cuda_graph=True)) + compilation_config=CompilationConfig(cudagraph_mode="FULL")) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 06ac3527e..2d1a72d44 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -11,10 +11,10 @@ from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op global_counter = 0 @@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor): num_backend_compilations=3, # num_piecewise_capturable_graphs_seen num_cudagraph_captured= 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context({}, vllm_config=vllm_config): - + ), set_forward_context(None, + vllm_config=vllm_config): # background context + # warm up with background context model(inputs) - model(torch.randn(2).cuda()) - model(torch.randn(1).cuda()) + # capturing/replaying should under context of cudagraph dispatching + with set_forward_context( + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor(num_tokens=2, )): + model(torch.randn(2).cuda()) + with set_forward_context( + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor(num_tokens=1, )): + model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() global global_counter global_counter = 0 - output = model(input) + with set_forward_context( + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor(num_tokens=2, )): + output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index b7ed8353b..bcfd0d834 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -18,9 +18,9 @@ from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) -from vllm.forward_context import set_forward_context +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + VllmConfig, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -276,9 +276,11 @@ def run_model(llama_config, ) if split_attn: compilation_config.splitting_ops = ["silly.attention"] + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: compilation_config = CompilationConfig( level=CompilationLevel.NO_COMPILATION, ) + cudagraph_runtime_mode = CUDAGraphMode.NONE vllm_config = VllmConfig(compilation_config=compilation_config, additional_config=llama_config) @@ -287,17 +289,37 @@ def run_model(llama_config, vllm_config=vllm_config, prefix="").eval().cuda() - with set_forward_context({}, vllm_config=vllm_config): + with set_forward_context({}, + vllm_config=vllm_config): # background context B = 16 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() positions = torch.arange(B).cuda() + # warmup for the model with cudagraph_mode NONE model(input_ids, positions) - model(input_ids[:2], positions[:2]) - model(input_ids[:1], positions[:1]) + + # simulate cudagraphs capturing + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + model(input_ids[:2], positions[:2]) + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, )): + model(input_ids[:1], positions[:1]) input_ids[:2].zero_() - output = model(input_ids[:2], positions[:2]) + # simulate cudagraphs replay + with set_forward_context({}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, )): + output = model(input_ids[:2], positions[:2]) output = output.cpu() diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py index d0687c62b..2d882bdf4 100644 --- a/tests/kernels/attention/test_aiter_flash_attn.py +++ b/tests/kernels/attention/test_aiter_flash_attn.py @@ -9,10 +9,10 @@ import torch import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401 from vllm.platforms import current_platform -NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] +BLOCK_SIZES = [16] +DTYPES = [torch.bfloat16] QDTYPES = [None] # one value large enough to test overflow in index calculation. # one value small enough to test the schema op check diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 2e0b4efeb..708366157 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -29,17 +29,14 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 NUM_BLOCKS = 4321 # Arbitrary values for testing PARTITION_SIZE = 512 PARTITION_SIZE_ROCM = 256 -# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} -DTYPES = [ - torch.half, torch.bfloat16, torch.float -] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] +DTYPES = [torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # This should be sync with get_supported_head_sizes() in # vllm.attention.ops.paged_attn.PagedAttention -HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256] +HEAD_SIZES = [32, 80, 128, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 789507615..8c3cc8cba 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -11,11 +11,11 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] -DTYPES = [torch.half, torch.bfloat16, torch.float] +DTYPES = [torch.bfloat16, torch.float] NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing -HEAD_SIZES = [64, 80, 120, 256] +HEAD_SIZES = [64, 80, 256] BLOCK_SIZES = [8, 16, 32] CACHE_LAYOUTS = ["NHD", "HND"] diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index bd3190d09..2544703f8 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -12,14 +12,16 @@ from vllm.vllm_flash_attn import (fa_version_unsupported_reason, flash_attn_with_kvcache, is_fa_version_supported) -NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] +BLOCK_SIZES = [16] +DTYPES = [torch.bfloat16] QDTYPES = [None, torch.float8_e4m3fn] # one value large enough to test overflow in index calculation. # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] +SOFT_CAPS = [None, 50.0] +SLIDING_WINDOWS = [None, 256] def ref_paged_attn( @@ -83,9 +85,9 @@ def ref_paged_attn( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("q_dtype", QDTYPES) @torch.inference_mode() @@ -198,9 +200,9 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("q_dtype", QDTYPES) diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index 8f9b4ecea..be78f0e4f 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -9,11 +9,13 @@ import torch from vllm.platforms import current_platform -NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)] +NUM_HEADS = [(32, 8), (6, 1)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] +DTYPES = [torch.bfloat16] NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. +SOFT_CAPS = [None, 30.0] +SLIDING_WINDOWS = [None, 64] def ref_paged_attn( @@ -76,8 +78,8 @@ def ref_paged_attn( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) -@pytest.mark.parametrize("sliding_window", [None, 64]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @torch.inference_mode def test_flashinfer_decode_with_paged_kv( kv_lens: list[int], @@ -173,8 +175,8 @@ def test_flashinfer_decode_with_paged_kv( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) -@pytest.mark.parametrize("sliding_window", [None, 64]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @torch.inference_mode def test_flashinfer_prefill_with_paged_kv( seq_lens: list[tuple[int, int]], @@ -278,11 +280,11 @@ def test_flashinfer_prefill_with_paged_kv( @pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) -@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) def test_flashinfer_prefill_with_paged_fp8_kv( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, @@ -385,11 +387,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv( @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@pytest.mark.skip(reason="TODO: fix the accuracy issue") @torch.inference_mode def test_flashinfer_decode_with_paged_fp8_kv( kv_lens: list[int], @@ -399,7 +402,6 @@ def test_flashinfer_decode_with_paged_fp8_kv( block_size: int, soft_cap: Optional[float], ) -> None: - pytest.skip("TODO: fix the accuracy issue") # test doesn't work for num_heads = (16,16) torch.set_default_device("cuda") current_platform.seed_everything(0) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index e87ce520b..53e225ea3 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -20,11 +20,11 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 MAX_Q_LEN = 1024 MAX_KV_LEN = 4096 BATCH_SIZES = [4, 12] -NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)] +NUM_HEADS = [(16, 16), (40, 8)] HEAD_SIZES = [128] -BLOCK_SIZES = [16, 32] +BLOCK_SIZES = [16] KV_LAYOUTS = ["HND"] -DTYPES = [torch.float16, torch.bfloat16] +DTYPES = [torch.bfloat16] KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()] NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. SOFT_CAPS = [None, 50.0] diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index b09e1bbc4..8544eab3a 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -19,13 +19,13 @@ from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] -NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128, 96, 24] +NUM_QUERIES_PER_KV = [1, 64] +HEAD_SIZES = [24, 128] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] +SLIDING_WINDOW = [0, 16, 2048] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] OPS = [chunked_prefill_paged_decode, context_attention_fwd] diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 0cb7f5963..4b97d51e6 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -9,11 +9,11 @@ import torch from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.platforms import current_platform -NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] +BLOCK_SIZES = [16] -DTYPES = [torch.float16, torch.bfloat16] +DTYPES = [torch.bfloat16] QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [ None, torch.float8_e4m3fnuz ] @@ -85,7 +85,7 @@ def ref_paged_attn( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("soft_cap", [None, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) @torch.inference_mode() diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index d2b893fff..2c554baaf 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -9,7 +9,7 @@ from einops import rearrange, repeat from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform -from vllm.v1.attention.backends.mamba_attn import ( +from vllm.v1.attention.backends.mamba2_attn import ( _query_start_loc_to_chunk_indices_offsets) # Added by the IBM Team, 2024 diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 69317405d..edf3e6189 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -89,14 +89,11 @@ class BatchedMMTensors: return BatchedMMTensors(A, B, C, num_expert_tokens) -@pytest.mark.parametrize("num_experts", [8, 16, 32]) -@pytest.mark.parametrize("max_tokens_per_expert", - [32, 64, 128, 192, 224, 256, 512]) -@pytest.mark.parametrize("K", [128, 256, 1024]) -@pytest.mark.parametrize("N", [128, 256, 1024]) -@pytest.mark.parametrize( - "dtype", - [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("num_experts", [8, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512]) +@pytest.mark.parametrize("K", [128, 1024]) +@pytest.mark.parametrize("N", [128, 1024]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, diff --git a/tests/kernels/moe/test_count_expert_num_tokens.py b/tests/kernels/moe/test_count_expert_num_tokens.py index 0872836b6..1768baaf1 100644 --- a/tests/kernels/moe/test_count_expert_num_tokens.py +++ b/tests/kernels/moe/test_count_expert_num_tokens.py @@ -113,8 +113,7 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int, rtol=0) -@pytest.mark.parametrize( - "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317]) +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317]) @pytest.mark.parametrize("num_topk", [2, 6, 8]) @pytest.mark.parametrize("num_experts", [64]) @pytest.mark.parametrize("ep_size", [1, 2, 4]) @@ -126,7 +125,7 @@ def test_compute_expert_num_tokens(num_tokens: int, num_topk: int, ep_size, topk_ids_dtype) -@pytest.mark.parametrize("numel", list(range(1, 8192, 11))) +@pytest.mark.parametrize("numel", list(range(1, 8192, 111))) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("ep_size", [2]) @pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 49c097718..1951eb0c6 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_permute_bias) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - rand_marlin_weight_fp4_like) + rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -40,6 +42,24 @@ NUM_EXPERTS = [8, 64, 192] EP_SIZE = [1, 4] TOP_KS = [2, 6] +FUSED_MOE_MNK_FACTORS = [ + (1, 128, 128), + (1, 2048, 128), + (33, 2048, 128), + (222, 1024, 1024), + (32768, 128, 128), + (32768, 2048, 511), + (40000, 1024, 1024), +] + +FUSED_MOE_WN16_MNK_FACTORS = [ + (1, 128, 128), + (1, 1024, 1024), + (32, 2048, 128), + (32, 1024, 1024), + (222, 2048, 1024), +] + vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -114,13 +134,11 @@ def run_moe_test( return baseline_output -@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("ep_size", EP_SIZE) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize("chunk_size", [8192]) def test_fused_moe( @@ -233,13 +251,11 @@ def test_fused_moe( use_cudagraph=use_cudagraph) -@pytest.mark.parametrize("m", [1, 32, 222]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("ep_size", EP_SIZE) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) @@ -350,8 +366,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) @@ -476,8 +491,11 @@ def marlin_moe_generate_valid_test_cases(): if quant_type == scalar_types.float8_e4m3fn and \ group_size not in [-1, 128]: return False - if quant_type == scalar_types.float4_e2m1f and group_size != 16: - return False + if quant_type == scalar_types.float4_e2m1f: + if group_size not in [16, 32]: + return False + if dtype == torch.float16 and group_size == 32: + return False if quant_type != scalar_types.float4_e2m1f and group_size == 16: return False @@ -520,31 +538,6 @@ def test_fused_marlin_moe( torch.cuda.manual_seed(0) has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] - if quant_type == scalar_types.float8_e4m3fn: - if group_size not in [-1, 128]: - return - if act_order: - return - - # Filter act_order - if act_order: - if quant_type == scalar_types.float8_e4m3fn: - return - if group_size == -1: - return - if group_size in (k, n): - return - if has_zp: - return - else: - if not is_k_full: - return - - if quant_type == scalar_types.float4_e2m1f and group_size != 16: - return - if quant_type != scalar_types.float4_e2m1f and group_size == 16: - return - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 @@ -569,13 +562,19 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref1, qweight1, scales1, global_scale1 = \ - rand_marlin_weight_fp4_like(w1[i], group_size) + if group_size == 16: + w_ref1, qweight1, scales1, global_scale1 = \ + rand_marlin_weight_nvfp4_like(w1[i], group_size) + else: + w_ref1, qweight1, scales1 = \ + rand_marlin_weight_mxfp4_like(w1[i], group_size) + global_scale1 = None w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) - global_scale1_l.append(global_scale1) + if global_scale1 is not None: + global_scale1_l.append(global_scale1) elif quant_type == scalar_types.float8_e4m3fn: w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( w1[i], group_size) @@ -620,13 +619,19 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref2, qweight2, scales2, global_scale2 = \ - rand_marlin_weight_fp4_like(w2[i], group_size) + if group_size == 16: + w_ref2, qweight2, scales2, global_scale2 = \ + rand_marlin_weight_nvfp4_like(w2[i], group_size) + else: + w_ref2, qweight2, scales2 = \ + rand_marlin_weight_mxfp4_like(w2[i], group_size) + global_scale2 = None w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) - global_scale2_l.append(global_scale2) + if global_scale2 is not None: + global_scale2_l.append(global_scale2) elif quant_type == scalar_types.float8_e4m3fn: w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( w2[i], group_size) @@ -677,6 +682,8 @@ def test_fused_marlin_moe( a, qweight1, qweight2, + None, + None, scales1, scales2, score, @@ -698,6 +705,119 @@ def test_fused_marlin_moe( torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +@pytest.mark.parametrize("m", [1, 256]) +def test_fused_marlin_moe_with_bias(m): + torch.cuda.manual_seed(0) + + e, topk = 32, 4 + n, k = 2048, 2048 + group_size = 128 + act_order = False + is_k_full = True + quant_type = scalar_types.uint4b8 + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10 + b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10 + + b_bias1_l = [] + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ + marlin_quantize(w1[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + b_bias1_l.append(marlin_permute_bias(b_bias1[i])) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + global_scale1 = None + g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None + zeros1 = None + sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None + marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None + + b_bias2_l = [] + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ + marlin_quantize(w2[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + b_bias2_l.append(marlin_permute_bias(b_bias2[i])) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + global_scale2 = None + g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None + zeros2 = None + sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None + marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, + b_bias2) + + marlin_output = torch.ops.vllm.fused_marlin_moe( + a, + qweight1, + qweight2, + marlin_bias1, + marlin_bias2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=None, + global_scale1=global_scale1, + global_scale2=global_scale2, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, + quant_type_id=quant_type.id, + is_k_full=is_k_full) + + torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) + + def test_moe_align_block_size_opcheck(): num_experts = 4 block_size = 4 diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index 12ef9e776..5dfc8d9fa 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -15,10 +15,10 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.platforms import current_platform from vllm.utils import round_up -NUM_TOKENS = [1, 3, 7, 16, 256, 2256, 4096] -NUM_EXPERTS = [32, 160, 256, 257, 512] +NUM_TOKENS = [1, 3, 256, 2256, 4096] +NUM_EXPERTS = [32, 160, 256, 257] TOP_KS = [1, 2, 16, 32] -BLOCK_SIZES = [32, 64, 128, 256] +BLOCK_SIZES = [32, 128] current_platform.seed_everything(0) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 8d215a0cb..6ca01f927 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( from vllm.platforms import current_platform NUM_EXPERTS = [16, 64, 256] -TOP_KS = [2, 4, 6, 8] +TOP_KS = [2, 6, 8] EP_SIZE = [1, 4, 16] current_platform.seed_everything(0) @@ -177,11 +177,11 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, return output -@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000]) -@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168]) +@pytest.mark.parametrize("n_token", [1, 33, 1024, 5000]) +@pytest.mark.parametrize("n_hidden", [2048, 7168]) @pytest.mark.parametrize("n_expert", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("align_block_size", [None, 128]) def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index f7a661b4b..fbef6706b 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -44,6 +44,14 @@ requires_pplx = pytest.mark.skipif( reason="Requires PPLX kernels", ) +BATCHED_MOE_MNK_FACTORS = [ + (1, 128, 128), + (33, 2048, 128), + (64, 128, 2048), + (222, 128, 128), + (222, 2048, 1024), +] + PPLX_COMBOS = [ # TODO: figure out why this fails, seems to be test problem #(1, 128, 128), @@ -152,9 +160,7 @@ def torch_batched_moe( return torch_finalize(out, topk_weight, topk_ids) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py new file mode 100644 index 000000000..131086a5f --- /dev/null +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, + convert_swizzled_to_linear, dequantize_nvfp4_to_dtype) + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +DTYPES = [torch.float16, torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES.extend(PAD_SHAPES) + +SEEDS = [42] +CUDA_DEVICES = ["cuda:0"] + + +def get_ref_results( + a_fp4, + b_fp4, + a_sf, + b_sf, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, +): + _, m_k = a_fp4.shape + _, n_k = b_fp4.shape + assert m_k == n_k + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, + b_sf, + b_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + return torch.matmul(a_in_dtype, b_in_dtype.t()) + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("backend", ["cutlass", "trtllm"]) +@pytest.mark.parametrize("autotune", [False, True]) +@torch.inference_mode() +def test_flashinfer_nvfp4_gemm( + dtype: torch.dtype, + shape: tuple[int, int, int], + seed: int, + device: str, + backend: str, + autotune: bool, +) -> None: + if backend == "trtllm" and dtype == torch.float16: + pytest.skip( + "Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") + + current_platform.seed_everything(seed) + m, n, packed_k = shape + k = packed_k * 2 + block_size = 16 + a_dtype = torch.randn((m, k), dtype=dtype, device=device) + b_dtype = torch.randn((n, k), dtype=dtype, device=device) + + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) + b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) + alpha = 1.0 / (a_global_scale * b_global_scale) + # ops.scaled_fp4_quant returns swizzled scales, while weights + # from checkpoints are in linear scales. + # So instead of needing to swizzle for cutlass as in modelopt.py, + # we need to unswizzle for trtllm here. + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) + b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) + + # get_ref_results unswizzles the scales internally. + expected_out = get_ref_results( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, + ) + + import flashinfer + + if backend == "trtllm": + epilogue_tile_m = 128 + b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), + epilogue_tile_m) + + b_scale_interleaved = convert_swizzled_to_linear( + b_scale_interleaved, n, k, block_size) + b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a( + b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape( + b_scale_interleaved.shape).view(torch.float8_e4m3fn)) + + with flashinfer.autotune(autotune): + out = flashinfer_scaled_fp4_mm( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + alpha, + dtype, + backend=backend, + ) + + torch.testing.assert_close(out, + expected_out.to(dtype=dtype), + atol=1e-1, + rtol=1e-1) diff --git a/tests/kernels/quantization/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py index 0a3edd4dd..c2e70ffb8 100644 --- a/tests/kernels/quantization/test_fp8_quant.py +++ b/tests/kernels/quantization/test_fp8_quant.py @@ -11,11 +11,9 @@ from tests.kernels.quant_utils import (FP8_DTYPE, from tests.kernels.utils import opcheck from vllm.platforms import current_platform -DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, - 8193] # Arbitrary values for testing -HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases -NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +DTYPES = [torch.bfloat16, torch.float] +HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193] +NUM_TOKENS = [1, 7, 4096] SCALE_UBS = [True, False] SEEDS = [0] diff --git a/tests/kernels/quantization/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py index 5a37b976d..c1c9bf191 100644 --- a/tests/kernels/quantization/test_int8_quant.py +++ b/tests/kernels/quantization/test_int8_quant.py @@ -9,10 +9,9 @@ from tests.kernels.utils import opcheck from vllm._custom_ops import scaled_int8_quant from vllm.platforms import current_platform -DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing -HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases -NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +DTYPES = [torch.bfloat16, torch.float] +HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193] +NUM_TOKENS = [1, 7, 4096] SEEDS = [0] SCALE = [0.1, 2.1] diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index a7cb2a4e7..a842d2f1c 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -34,8 +34,6 @@ IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 MNK_SHAPES = [ (1, 128, 128), - (1, 512, 1024), - (1, 4096, 4096), (1, 8192, 28672), (13, 8192, 4096), (26, 4096, 8192), @@ -43,8 +41,6 @@ MNK_SHAPES = [ (64, 8192, 28672), (257, 128, 4096), (257, 4224, 4160), - (257, 4096, 4096), - (1024, 4096, 8192), (1024, 8192, 4096), ] diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 92914bd5c..cea7700ac 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -19,10 +19,11 @@ from vllm.model_executor.layers.quantization.qqq import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_scales, + marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) + FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like, + rand_marlin_weight_nvfp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -39,7 +40,7 @@ from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] USE_ATOMIC_ADD_OPTS = [False, True] -USE_FP32_REDUCE_OPTS = [False, True] +USE_FP32_REDUCE_OPTS = [True] MARLIN_K_CHUNKS = [128] MARLIN_N_CHUNKS = [64, 256] @@ -52,12 +53,8 @@ HQQ_SUPPORTED_GROUP_SIZES = [64] MNK_FACTORS = [ (1, 1, 1), (1, 4, 8), - (1, 7, 5), - (13, 17, 67), (26, 37, 13), - (67, 13, 11), (257, 13, 11), - (658, 13, 11), ] DTYPES = [torch.float16, torch.bfloat16] @@ -202,17 +199,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) -def test_gptq_marlin_gemm( - k_chunk, - n_chunk, - quant_type, - group_size, - mnk_factors, - act_order, - is_k_full, - use_atomic_add, - use_fp32_reduce, -): +@pytest.mark.parametrize("dtype", DTYPES) +def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, + mnk_factors, act_order, is_k_full, use_atomic_add, + use_fp32_reduce, dtype): m_factor, n_factor, k_factor = mnk_factors has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] @@ -231,14 +221,23 @@ def test_gptq_marlin_gemm( if size_k % group_size != 0: return - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) + a_input = rand_data((size_m, size_k), dtype) + b_weight = rand_data((size_k, size_n), dtype) if quant_type == scalar_types.float4_e2m1f: - if group_size != 16 or act_order: + if group_size not in [16, 32] or act_order: return - w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( - b_weight.T, group_size) + if group_size == 32 and dtype == torch.float16: + return + + if group_size == 16: + w_ref, marlin_q_w, marlin_s, marlin_s2 = \ + rand_marlin_weight_nvfp4_like(b_weight.T, group_size) + else: + w_ref, marlin_q_w, marlin_s = \ + rand_marlin_weight_mxfp4_like(b_weight.T, group_size) + marlin_s2 = None + g_idx = None sort_indices = None marlin_zp = None @@ -272,8 +271,8 @@ def test_gptq_marlin_gemm( workspace = marlin_make_workspace_new(w_ref.device) opcheck(torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, - sort_indices, workspace, quant_type.id, a_input.shape[0], + (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp, + g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), test_utils=DEFAULT_OPCHECK_TEST_UTILS) @@ -282,6 +281,7 @@ def test_gptq_marlin_gemm( a_input, None, marlin_q_w, + None, marlin_s, marlin_s2, marlin_zp, @@ -418,6 +418,7 @@ def test_hqq_marlin_gemm( a_input, None, marlin_w_q, + None, marlin_s, None, marlin_zp, @@ -531,6 +532,7 @@ def test_marlin_gemm_subset_input(): a_input, None, marlin_q_w, + None, marlin_s, None, marlin_zp, @@ -555,6 +557,53 @@ def test_marlin_gemm_subset_input(): assert max_diff < 0.04 +@pytest.mark.parametrize("size_m", [1, 256]) +def test_marlin_gemm_with_bias(size_m): + quant_type = scalar_types.uint4b8 + group_size = 128 + + size_k, size_n = 1024, 2048 + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + b_bias = rand_data((size_n, )) * 10 + + marlin_bias = marlin_permute_bias(b_bias) + + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, quant_type, group_size, False) + + marlin_zp = marlin_make_empty_g_idx(marlin_s.device) + workspace = marlin_make_workspace_new(a_input.device) + + output = ops.gptq_marlin_gemm( + a_input, + None, + marlin_q_w, + marlin_bias, + marlin_s, + None, + marlin_zp, + g_idx, + sort_indices, + workspace, + quant_type, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full=True, + use_atomic_add=False, + use_fp32_reduce=True, + is_zp_float=False, + ) + output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + + assert max_diff < 0.04 + + def test_marlin_gemm_opcheck(): size_m = 2048 size_n = 4096 diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index 0b45c2298..67e041f2b 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -65,9 +65,12 @@ def test_nvfp4_gemm( b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) alpha = 1. / (a_global_scale * b_global_scale) + # ops.scaled_fp4_quant returns swizzled scales, while weights + # from checkpoints are in linear scales. a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) + # get_ref_results unswizzles the scales internally. expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, a_global_scale, b_global_scale, m, n, dtype, block_size, diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 533a4fe59..03d5d9873 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -8,15 +8,55 @@ from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float16] -M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] -K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0 -N = [1, 2, 3, 4] +# Specific (N, K, M) combinations for targeted testing +NKM_FACTORS_LLMM1 = [ + # Small, medium, large cases + (1, 8, 16), + (1, 32, 64), + (1, 128, 256), + (1, 512, 1024), + (1, 2048, 4096), + # Edge cases with specific K sizes + (1, 6144, 1024), + (1, 8192, 2048), + # Very large case + (1, 4096, 8192), +] + +NKM_FACTORS_WVSPLITK = [ + # Different batch sizes with key dimensions + (1, 16, 16), + (1, 64, 64), + (2, 256, 256), + (3, 1024, 1024), + (4, 4096, 4096), + # Extended K values + (1, 9216, 512), + (2, 10240, 1024), + (4, 16384, 8192), + # Minimum M constraint validation (m >= 8) + (1, 64, 8), + (2, 128, 8), + (4, 256, 8), +] + +NKM_FACTORS_WVSPLITK_FP8 = [ + # FP8-specific cases with K % 16 == 0 + (1, 16, 16), + (1, 64, 64), + (2, 512, 512), + (3, 2048, 2048), + (4, 4096, 4096), + # Extended FP8 dimensions not covered by WVSPLITK + (1, 14336, 1024), + (2, 24576, 2048), + (4, 32768, 28672), +] + SEEDS = [0] -@pytest.mark.parametrize("n", [1]) # only test for batch size 1 -@pytest.mark.parametrize("k", K) -@pytest.mark.parametrize("m", M) +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) @pytest.mark.parametrize("seed", SEEDS) @@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): assert torch.allclose(out, ref_out, rtol=0.01) -@pytest.mark.parametrize("n", N) # only test for batch size <= 4 -@pytest.mark.parametrize("k", K + [9216, 10240, 16384]) -@pytest.mark.parametrize("m", [8] + M) # m >= 8 +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not current_platform.is_rocm(), @@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): assert torch.allclose(out, ref_out, rtol=0.01) -@pytest.mark.parametrize("n", N) # only test for batch size <= 4 -@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0 -@pytest.mark.parametrize("m", M + [28672]) # m >= 16 +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif( diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index 8a2cc3bac..24245663f 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path, num_logprobs) -@pytest.mark.parametrize("M", [1, 33, 64, 512]) -@pytest.mark.parametrize("N", [256, 971, 20486]) -@pytest.mark.parametrize("K", [128, 496, 1024]) -@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) +MNK_FACTORS = [ + (1, 256, 128), + (33, 256, 496), + (64, 971, 1024), + (64, 20486, 128), + (512, 256, 496), + (512, 20486, 1024), +] + + +@pytest.mark.parametrize("M,N,K", MNK_FACTORS) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) @pytest.mark.parametrize("in_dtype", get_8bit_types()) @pytest.mark.parametrize("use_scalar_scale_a", [True, False]) @pytest.mark.parametrize("use_scalar_scale_b", [True, False]) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2e8febbdc..fa4125840 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1064,6 +1064,8 @@ def torch_experts( topk_weight: torch.Tensor, topk_ids: torch.Tensor, global_num_experts: int = -1, + b_bias1: Optional[torch.Tensor] = None, + b_bias2: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -1108,8 +1110,13 @@ def torch_experts( if mask.sum(): if quant_dtype is None: tmp1 = a[mask] @ w1[i].transpose(0, 1) + if b_bias1 is not None: + tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype) tmp2 = SiluAndMul()(tmp1) out[mask] = tmp2 @ w2[i].transpose(0, 1) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + tmp1.dtype) elif block_shape is not None: # block quantized assert (a_scale is not None and w1_scale is not None @@ -1117,6 +1124,8 @@ def torch_experts( tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, out.dtype) + if b_bias1 is not None: + tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype) tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = moe_kernel_quantize_input( tmp2, a2_scale, quant_dtype, per_act_token_quant, @@ -1125,6 +1134,9 @@ def torch_experts( out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, w2_scale[i], block_shape, out.dtype) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + tmp1.dtype) else: assert (a_scale is not None and w1_scale is not None and w2_scale is not None) @@ -1133,6 +1145,8 @@ def torch_experts( tmp1 = a[mask].to(f32) * scales w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) tmp1 = (tmp1 @ w1_dq).to(out.dtype) + if b_bias1 is not None: + tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype) tmp2 = SiluAndMul()(tmp1).to(out.dtype) @@ -1144,6 +1158,9 @@ def torch_experts( tmp2 = tmp2.to(f32) * b_scale w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) out[mask] = (tmp2 @ w2_dq).to(out.dtype) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + out.dtype) if apply_router_weights_on_input: return out @@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor, w2: torch.Tensor, score: torch.Tensor, topk: int, + b_bias1: Optional[torch.Tensor] = None, + b_bias2: Optional[torch.Tensor] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, - expert_map) + b_bias1, b_bias2, expert_map) def torch_moe_single(a, w, score, topk): diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 19fcbf561..aee0a5033 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -57,6 +57,13 @@ V1_SUPPORTED_MODELS = [ # Avoid OOM MAX_NUM_SEQS = 4 +# Once we add support for FCG in Mamba1, this list will be removed and tests +# all test cases will use enforce_eager=False +ENFORCE_EAGER_MODELS_V1 = [ + "state-spaces/mamba-130m-hf", + "ai21labs/Jamba-tiny-dev", +] + @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @@ -94,13 +101,19 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: + enforce_eager = False with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: # required due to reorder_batch behaviour m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + + if model in ENFORCE_EAGER_MODELS_V1: + enforce_eager = True + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, + enforce_eager=enforce_eager, enable_prefix_caching=False) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) @@ -418,3 +431,65 @@ def test_full_cuda_graph( name_0="hf" if hf_outputs is not None else "vllm-v0", name_1="vllm-v1", ) + + +@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_fp32_state( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + num_logprobs: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + with hf_runner(model) as hf_model: + if model not in HF_UNSUPPORTED_MODELS: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + else: + hf_outputs = None + + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + mamba_ssm_cache_dtype="float32") as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + if model in HYBRID_MODELS: + # required due to reorder_batch behaviour + m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + mamba_ssm_cache_dtype="float32", + enable_prefix_caching=False) as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + if hf_outputs is not None: + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_v0_outputs, + name_0="hf", + name_1="vllm-v0", + ) + + ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs + check_logprobs_close( + outputs_0_lst=ref_outputs, + outputs_1_lst=vllm_v1_outputs, + name_0="hf" if hf_outputs is not None else "vllm-v0", + name_1="vllm-v1", + ) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index d024c76dd..4a1f8a53d 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -18,7 +18,7 @@ from tests.models.utils import EmbedModelInfo, RerankModelInfo # - Different model results in differences more than 1e-3 # 1e-4 is a good tolerance threshold MTEB_EMBED_TASKS = ["STS12"] -MTEB_EMBED_TOL = 1e-4 +MTEB_EMBED_TOL = 0.02 # See #19344 MTEB_RERANK_TASKS = ["NFCorpus"] @@ -175,6 +175,7 @@ def mteb_test_embed_models(hf_runner, with vllm_runner(model_info.name, runner="pooling", max_model_len=None, + enforce_eager=True, **vllm_extra_kwargs) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config @@ -198,6 +199,7 @@ def mteb_test_embed_models(hf_runner, st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) st_dtype = next(hf_model.model.parameters()).dtype + print("Model:", model_info.name) print("VLLM:", vllm_dtype, vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) @@ -286,6 +288,7 @@ def mteb_test_rerank_models(hf_runner, runner="pooling", max_model_len=None, max_num_seqs=8, + enforce_eager=True, **vllm_extra_kwargs) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config @@ -304,6 +307,7 @@ def mteb_test_rerank_models(hf_runner, st_main_score, st_dtype = mteb_test_rerank_models_hf( hf_runner, model_info.name, hf_model_callback) + print("Model:", model_info.name) print("VLLM:", vllm_dtype, vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) diff --git a/tests/models/registry.py b/tests/models/registry.py index eb48c0f6a..3efc9a99e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -151,7 +151,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1", - min_transformers_version="4.55.1", + min_transformers_version="4.56.0", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", {"1b": "bigscience/bloomz-1b1"}), @@ -227,7 +227,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", - min_transformers_version="4.55.1", + min_transformers_version="4.56.0", extras={ "tiny": "ai21labs/Jamba-tiny-dev", "random": "ai21labs/Jamba-tiny-random", # noqa: E501 diff --git a/tests/multimodal/test_hasher.py b/tests/multimodal/test_hasher.py index 42cb40739..75a233c25 100644 --- a/tests/multimodal/test_hasher.py +++ b/tests/multimodal/test_hasher.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import uuid from pathlib import Path import numpy as np @@ -72,3 +73,22 @@ def test_hash_non_contiguous_array(): hasher = MultiModalHasher # Both should be hashable and produce the same hashes assert hasher.hash_kwargs(data=arr) == hasher.hash_kwargs(data=arr_c) + + +def test_hash_image_exif_id(): + # Test that EXIF ImageId tag can be used to store UUID + # and the hasher will use that instead of the image data. + image1 = image2 = Image.new("1", size=(10, 20)) + id = uuid.uuid4() + image1.getexif()[Image.ExifTags.Base.ImageID] = id + image2 = Image.open(ASSETS_DIR / "image1.png") + image2.getexif()[Image.ExifTags.Base.ImageID] = "Not a UUID" + image2a = Image.open(ASSETS_DIR / "image1.png") + + hasher = MultiModalHasher + # first image has UUID in ImageID, so it should hash to that UUID + assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs( + image=id.bytes) + # second image has non-UUID in ImageID, so it should hash to the image data + assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs( + image=image2a) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 41f4773a1..ea964a543 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -148,6 +148,32 @@ async def test_fetch_image_local_files(image_url: str): f"file://{temp_dir}/../{os.path.basename(image_url)}") +@pytest.mark.asyncio +async def test_fetch_image_local_files_with_space_in_name(): + image_url = TEST_IMAGE_URLS[0] + connector = MediaConnector() + + with TemporaryDirectory() as temp_dir: + local_connector = MediaConnector(allowed_local_media_path=temp_dir) + + origin_image = connector.fetch_image(image_url) + filename = "file name with space.jpg" + origin_image.save(os.path.join(temp_dir, filename), + quality=100, + icc_profile=origin_image.info.get('icc_profile')) + + try: + image_async = await local_connector.fetch_image_async( + f"file://{temp_dir}/{filename}") + image_sync = local_connector.fetch_image( + f"file://{temp_dir}/{filename}") + except FileNotFoundError as e: + pytest.fail( + "Failed to fetch image with space in name: {}".format(e)) + # Check that the images are equal + assert not ImageChops.difference(image_sync, image_async).getbbox() + + @pytest.mark.asyncio async def test_fetch_image_error_conversion(): connector = MediaConnector() diff --git a/tests/standalone_tests/python_only_compile.sh b/tests/standalone_tests/python_only_compile.sh index ec1bcbcc5..7cc5ef659 100644 --- a/tests/standalone_tests/python_only_compile.sh +++ b/tests/standalone_tests/python_only_compile.sh @@ -10,7 +10,7 @@ cd /vllm-workspace/ # uninstall vllm pip3 uninstall -y vllm # restore the original files -mv test_docs/vllm ./vllm +mv src/vllm ./vllm # remove all compilers apt remove --purge build-essential -y diff --git a/tests/v1/attention/test_mamba_selectors.py b/tests/v1/attention/test_mamba_selectors.py index 8eaafc5e1..4245b50c7 100644 --- a/tests/v1/attention/test_mamba_selectors.py +++ b/tests/v1/attention/test_mamba_selectors.py @@ -4,7 +4,7 @@ import pytest -from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend diff --git a/tests/v1/cudagraph/__init__.py b/tests/v1/cudagraph/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py new file mode 100644 index 000000000..64f2fa462 --- /dev/null +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn + +from tests.utils import create_new_process_for_each_test +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.monitor import set_cudagraph_capturing_enabled +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.platforms import current_platform +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher + + +# Helper MLP for testing +class SimpleMLP(nn.Module): + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 10) + self.fc2 = nn.Linear(10, 10) + + def forward(self, x): + return self.fc2(self.fc1(x)) + + +def _create_vllm_config(compilation_config: CompilationConfig, + max_num_seqs: int = 8) -> MagicMock: + mock_config = MagicMock(spec=VllmConfig) + mock_config.compilation_config = compilation_config + mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) + mock_config.parallel_config = ParallelConfig() + + # Mimic the behavior of VllmConfig.__post_init__() + if compilation_config.level == CompilationLevel.PIECEWISE: + compilation_config.set_splitting_ops_for_v1() + + return mock_config + + +class TestCudagraphDispatcher: + + @pytest.mark.parametrize( + "params", + [ + # Test case 0: Full CG for mixed batches, no separate routine + { + "case_id": 0, + "cudagraph_mode": "FULL", + "compilation_level": CompilationLevel.NO_COMPILATION, + }, + # Test case 1: Full CG for uniform batches, piecewise for mixed + { + "case_id": 1, + "cudagraph_mode": "FULL_AND_PIECEWISE", + "compilation_level": CompilationLevel.PIECEWISE, + }, + # Test case 2: Full CG for uniform batches, no CG for mixed + { + "case_id": 2, + "cudagraph_mode": "FULL_DECODE_ONLY", + "compilation_level": CompilationLevel.NO_COMPILATION, + }, + # Test case 3: Piecewise for all + { + "case_id": 3, + "cudagraph_mode": "PIECEWISE", + "compilation_level": CompilationLevel.PIECEWISE, + }, + ]) + def test_dispatcher(self, params): + # Setup dispatcher + comp_config = CompilationConfig( + cudagraph_mode=params["cudagraph_mode"], + level=params["compilation_level"], + cudagraph_capture_sizes=[1, 8]) + + config = _create_vllm_config(comp_config, max_num_seqs=8) + dispatcher = CudagraphDispatcher(config) + dispatcher.initialize_cudagraph_keys( + cudagraph_mode=comp_config.cudagraph_mode, + uniform_decode_query_len=1) + + # Verify the key is initialized correctly + if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]: + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2 + else: + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0 + if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]: + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2 + else: + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0 + + # Test dispatch logic + # 1. non-uniform batch, size in cudagraph size list + desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) + rt_mode, key = dispatcher.dispatch(desc_full_exact) + if params["cudagraph_mode"] == "FULL": + assert rt_mode == CUDAGraphMode.FULL + assert key == desc_full_exact + elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]: + assert rt_mode == CUDAGraphMode.PIECEWISE + assert key == desc_full_exact + else: + assert rt_mode == CUDAGraphMode.NONE + + # 2. uniform decode batch, size in cudagraph size list + desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True) + rt_mode, key = dispatcher.dispatch(desc_uniform_exact) + if params["cudagraph_mode"] == "FULL": + assert rt_mode == CUDAGraphMode.FULL + assert key == desc_uniform_exact.non_uniform + elif params["cudagraph_mode"] in [ + "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE" + ]: + assert rt_mode == CUDAGraphMode.FULL + assert key == desc_uniform_exact + elif params["cudagraph_mode"] == "PIECEWISE": + assert rt_mode == CUDAGraphMode.PIECEWISE + assert key == desc_uniform_exact.non_uniform + else: + assert rt_mode == CUDAGraphMode.NONE + + # 3. No key match + desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False) + rt_mode, key = dispatcher.dispatch(desc_no_match) + assert rt_mode == CUDAGraphMode.NONE + assert key is None + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") +class TestCUDAGraphWrapper: + + def setup_method(self): + self.vllm_config = _create_vllm_config(CompilationConfig()) + self.model = SimpleMLP().to("cuda") + self.persistent_input_buffer = torch.zeros(1, 10, device="cuda") + self.input_tensor = torch.randn(1, 10, device="cuda") + + @create_new_process_for_each_test("spawn") + def test_capture_and_replay(self): + wrapper = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + batch_descriptor = BatchDescriptor(num_tokens=10) + + # 0. global warmup + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None): + wrapper(self.input_tensor) + + # 1. Capture + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + batch_descriptor=batch_descriptor),\ + patch("torch.cuda.graph", + wraps=torch.cuda.graph) as mock_cuda_graph: + output1 = wrapper(self.input_tensor) + # capturing phase should generate a zero output + assert torch.allclose(output1, torch.zeros_like(output1)) + mock_cuda_graph.assert_called_once() + + assert batch_descriptor in wrapper.concrete_cudagraph_entries + entry = wrapper.concrete_cudagraph_entries[batch_descriptor] + assert entry.cudagraph is not None + + # 2. Replay + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + batch_descriptor=batch_descriptor),\ + patch.object(entry.cudagraph, 'replay', + wraps=entry.cudagraph.replay) as mock_replay: + output2 = wrapper(self.input_tensor) + mock_replay.assert_called_once() + + # Compare with eager output + eager_output = self.model(self.input_tensor) + torch.testing.assert_close(eager_output, output2) + + @create_new_process_for_each_test("spawn") + def test_bypass_on_mode_mismatch(self): + wrapper = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + batch_descriptor = BatchDescriptor(num_tokens=10) + + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=batch_descriptor), \ + patch('torch.cuda.graph', + wraps=torch.cuda.graph) as mock_cuda_graph, \ + patch.object(self.model, 'forward', + wraps=self.model.forward) as mock_forward: + wrapper(self.input_tensor) + mock_cuda_graph.assert_not_called() + mock_forward.assert_called_once() + assert not wrapper.concrete_cudagraph_entries + + @create_new_process_for_each_test("spawn") + def test_bypass_on_mode_none(self): + wrapper = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + batch_descriptor = BatchDescriptor(num_tokens=10) + + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=batch_descriptor), \ + patch('torch.cuda.graph', + wraps=torch.cuda.graph) as mock_cuda_graph: + wrapper(self.input_tensor) + mock_cuda_graph.assert_not_called() + assert not wrapper.concrete_cudagraph_entries + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") +class TestCudagraphIntegration: + + def setup_method(self): + # only FULL mode for non-uniform batches + self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE, + cudagraph_mode="FULL", + cudagraph_capture_sizes=[10, 20]) + self.vllm_config = _create_vllm_config(self.comp_config) + self.dispatcher = CudagraphDispatcher(self.vllm_config) + self.dispatcher.initialize_cudagraph_keys( + self.comp_config.cudagraph_mode, uniform_decode_query_len=1) + + def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode, + batch_descriptor): + """Helper to run a single call and monitor the action.""" + + with patch('torch.cuda.graph', + wraps=torch.cuda.graph) as mock_graph_context, \ + patch.object(wrapper, 'runnable', + wraps=wrapper.runnable) as mock_runnable: + + entry = wrapper.concrete_cudagraph_entries.get( + batch_descriptor, None) + + context = set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=runtime_mode, + batch_descriptor=batch_descriptor) + mock_replay = MagicMock() + if entry and entry.cudagraph: + with context, \ + patch.object(entry.cudagraph, 'replay', + new_callable=MagicMock) as mock_replay: + wrapper(input_tensor) + else: + with context: + wrapper(input_tensor) + + if mock_graph_context.called: + # note that this is globally mocked, so it will be detected + # even whether called by the inner or outer wrapper + return "capture_global" + if mock_replay.called: + # only for outer wrapper + return "replay" + if mock_runnable.call_count > 0: + # only for outer wrapper + return "bypass" + return "unknown" + + @create_new_process_for_each_test("spawn") + def test_capture_replay_bypass_logic(self): + model = SimpleMLP().to("cuda") + full_wrapper = CUDAGraphWrapper(model, self.vllm_config, + CUDAGraphMode.FULL) + max_bs = 16 + persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda") + input_1 = persistent_input_buffer[:1] + input_2 = persistent_input_buffer[:2] + input_3 = persistent_input_buffer[:3] + + desc_1 = BatchDescriptor(num_tokens=1) + desc_2 = BatchDescriptor(num_tokens=2) + desc_3_unseen = BatchDescriptor(num_tokens=3) + + # 0. global warmup + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None): + full_wrapper(input_1) + + rt_mode, key = self.dispatcher.dispatch(desc_1) + # 1. Capture first shape + action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, + key) + assert action == "capture_global" + + # 2. Replay first shape + action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, + key) + assert action == "replay" + + rt_mode, key = self.dispatcher.dispatch(desc_2) + # 3. Capture second shape + action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, + key) + assert action == "capture_global" + + # 4. Replay second shape + action = self._run_and_monitor_call(full_wrapper, input_2, + CUDAGraphMode.FULL, desc_2) + assert action == "replay" + + # 5. Bypass if no key match + rt_mode, key = self.dispatcher.dispatch(desc_3_unseen) + assert rt_mode == CUDAGraphMode.NONE + action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, + key) + assert action == "bypass" + + # capture unseen shape is not allowed after disable + set_cudagraph_capturing_enabled(False) + with pytest.raises(RuntimeError): + self._run_and_monitor_call(full_wrapper, input_3, + CUDAGraphMode.FULL, desc_3_unseen) + set_cudagraph_capturing_enabled(True) + + @create_new_process_for_each_test("spawn") + def test_nested_wrappers(self): + """Tests a scenario with a PIECEWISE wrapper inside a FULL one.""" + model = SimpleMLP().to("cuda") + full_wrapper = CUDAGraphWrapper(model, self.vllm_config, + CUDAGraphMode.FULL) + input_1 = torch.randn(1, 10, device="cuda") + + # Setup: Inner model is wrapped with PIECEWISE, outer with FULL + inner_model = SimpleMLP().to("cuda") + piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config, + CUDAGraphMode.PIECEWISE) + inner_model.forward = MagicMock(wraps=inner_model.forward) + outer_model = SimpleMLP().to("cuda") + # When outer model is called, it calls the piecewise_wrapper + outer_model.forward = MagicMock(wraps=outer_model.forward, + side_effect=piecewise_wrapper) + full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config, + CUDAGraphMode.FULL) + + desc_1 = BatchDescriptor(num_tokens=1) + + # 0. global warmup + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None): + full_wrapper(input_1) + + # --- Test runtime mode FULL--- + # Run with FULL mode context. Expect outer wrapper to capture. + # The inner mock should be called once inside the graph capture. + outer_model.forward.reset_mock() + inner_model.forward.reset_mock() + action = self._run_and_monitor_call(full_wrapper, input_1, + CUDAGraphMode.FULL, desc_1) + assert action == "capture_global" + assert outer_model.forward.call_count == 1 + assert inner_model.forward.call_count == 1 + + # Run again. Expect outer wrapper to replay. + # The outer model should NOT be called because the whole graph + # is replayed. + action = self._run_and_monitor_call(full_wrapper, input_1, + CUDAGraphMode.FULL, desc_1) + assert action == "replay" + assert outer_model.forward.call_count == 1 # No new call + assert inner_model.forward.call_count == 1 + + # --- Test runtime mode PIECEWISE --- + outer_model.forward.reset_mock() + inner_model.forward.reset_mock() + # Run with PIECEWISE mode context. + # Expect outer wrapper to bypass and call inner wrapper. + # Inner wrapper should capture. + action = self._run_and_monitor_call(full_wrapper, input_1, + CUDAGraphMode.PIECEWISE, desc_1) + assert action == "capture_global" + assert outer_model.forward.call_count == 1 + assert inner_model.forward.call_count == 1 + + # Run again with PIECEWISE. + # Outer bypasses, inner replays. + action = self._run_and_monitor_call(full_wrapper, input_1, + CUDAGraphMode.PIECEWISE, desc_1) + assert action == "bypass" + assert outer_model.forward.call_count == 2 + assert inner_model.forward.call_count == 1 diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py new file mode 100644 index 000000000..81655e417 --- /dev/null +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +import weakref +from contextlib import ExitStack +from dataclasses import dataclass +from typing import Optional + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from vllm import LLM +from vllm.config import CompilationConfig +from vllm.platforms import current_platform + + +@contextlib.contextmanager +def temporary_environ(env_vars): + """ + Temporarily set environment variables and restore them afterward. + We have to do this vs monkeypatch because monkeypatch doesn't work + with "module" scoped fixtures. + """ + original_env = {k: os.environ.get(k) for k in env_vars} + try: + os.environ.update(env_vars) + yield + finally: + for k, v in original_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + +@dataclass +class BackendConfig: + name: str + env_vars: dict + comp_config: dict + specific_gpu_arch: Optional[tuple] = None + + +# Define all backend configurations of full cudagraph to be tested +backend_configs = { + # FA3 on Hopper + "FA3": + BackendConfig(name="FA3", + env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, + comp_config={ + "cudagraph_mode": "FULL", + }, + specific_gpu_arch=(9, 0)), + # FlashMLA on Hopper + "FlashMLA": + BackendConfig(name="FlashMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASHMLA", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(9, 0)), + # FA2 + "FA2": + BackendConfig(name="FA2", + env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), + # Triton Attention + "TritonAttn": + BackendConfig(name="TritonAttn", + env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), + # FlashInfer + "FlashInfer": + BackendConfig(name="FlashInfer", + env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }), +} + +# test attention backend and cudagraph_mode combo +# (backend_name, cudagraph_mode, supported) +combo_cases_1 = [ + ("FA3", "FULL", True), + ("FA3", "FULL_AND_PIECEWISE", True), + ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE + ("FA2", "FULL_AND_PIECEWISE", True), + ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE + ("FlashInfer", "FULL_AND_PIECEWISE", True), +] + + +@pytest.mark.parametrize("combo_case", combo_cases_1) +def test_backend_and_cudagraph_mode_combo(combo_case): + backend_name, cudagraph_mode, supported = combo_case + if backend_name == "FlashInfer": + try: + import flashinfer # noqa: F401 + except ImportError: + pytest.skip("FlashInfer is not installed") + backend_config = backend_configs[backend_name] + # Dynamically skip test if GPU capability is not met + if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ + != current_platform.get_device_capability(): + pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") + + env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars} + + with temporary_environ(env_vars), ExitStack() as stack: + if not supported: + stack.enter_context(pytest.raises(Exception)) + + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + max_num_seqs=256, + trust_remote_code=True, + gpu_memory_utilization=0.45, + max_model_len=1024, + compilation_config=CompilationConfig( + level=3, cudagraph_mode=cudagraph_mode)) + llm.generate(["Hello, my name is"] * 10) + + try: + llm = weakref.proxy(llm) + del llm + except UnboundLocalError: + pass + + wait_for_gpu_memory_to_clear( + devices=[0], + threshold_ratio=0.1, + ) + + +# test cudagraph_mode with different compilation level. +# (backend_name, cudagraph_mode, compilation_level, supported) +combo_cases_2 = [ + ("FA2", "FULL", 0, True), # no compilation + full cudagraph + ("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph + ("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph + ("FA2", "PIECEWISE", 3, + True), # piecewise compilation + piecewise cudagraph + ("FA2", "FULL_AND_PIECEWISE", 0, + False), # piecewise cudagraph not supported without piecewise compilation + ("FA2", "FULL_AND_PIECEWISE", 3, True), + ("FA2", "FULL_DECODE_ONLY", 0, True), + ("FA2", "FULL_DECODE_ONLY", 3, True), + ("FA2", "NONE", 0, True), # no compilation + no cudagraph + ("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph +] + + +@pytest.mark.parametrize("combo_case", combo_cases_2) +def test_cudagraph_compilation_combo(combo_case): + backend_name, cudagraph_mode, compilation_level, supported\ + = combo_case + + env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars} + + with temporary_environ(env_vars), ExitStack() as stack: + if not supported: + stack.enter_context(pytest.raises(Exception)) + + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + max_num_seqs=256, + trust_remote_code=True, + gpu_memory_utilization=0.45, + max_model_len=1024, + compilation_config=CompilationConfig( + level=compilation_level, cudagraph_mode=cudagraph_mode)) + llm.generate(["Hello, my name is"] * 10) + try: + llm = weakref.proxy(llm) + del llm + except UnboundLocalError: + pass + finally: + wait_for_gpu_memory_to_clear( + devices=[0], + threshold_ratio=0.1, + ) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index dde95fbe5..7b3f45831 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -162,6 +162,12 @@ def test_eagle_correctness( mm_enabled: bool, attn_backend: str, ): + if attn_backend == "TREE_ATTN": + # TODO: Fix this flaky test + pytest.skip( + "TREE_ATTN is flaky in the test disable for now until it can be " + "reolved (see https://github.com/vllm-project/vllm/issues/22922)") + # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) ''' diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 21694491d..484640233 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -13,6 +13,7 @@ from vllm.assets.image import ImageAsset from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType +from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.utils import set_default_torch_num_threads @@ -398,3 +399,89 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch): # Test 3: Verify healthy engine still works after mock await engine.check_health() + + +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_abort_final_output( + monkeypatch: pytest.MonkeyPatch, + output_kind: RequestOutputKind, +): + """Test that abort() returns a final output with correct information.""" + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) + after.callback(engine.shutdown) + + request_id = "test-abort-final-output" + + # Start a long-running request + sampling_params = SamplingParams( + max_tokens=3000, # Long enough to allow abort + ignore_eos=True, + output_kind=output_kind, + temperature=0.5, + seed=42, + ) + + outputs: list[RequestOutput] = [] + generated = asyncio.create_task( + collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, + outputs)) + + # Let it generate some tokens + await asyncio.sleep(0.5) + + # Abort the request + await engine.abort(request_id) + + # Wait for generation to complete and return final output + final_output = await generated + + # Verify we got a final output + assert final_output is not None + assert final_output.finished + assert len(final_output.outputs) == 1 + + assert final_output.outputs[0].finish_reason == "abort" + assert final_output.outputs[0].stop_reason is None + + # Verify num_cached_tokens is set correctly + assert hasattr(final_output, 'num_cached_tokens') + assert final_output.num_cached_tokens >= 0 + + # If we got intermediate outputs, verify they are consistent + if output_kind == RequestOutputKind.DELTA: + # For DELTA, sum all intermediate tokens should <= final tokens + token_count = sum( + len(output.outputs[0].token_ids) for output in outputs) + assert token_count > 0 + assert len(final_output.outputs[0].token_ids) == 0 + else: + # For FINAL_ONLY, we should only get the final output + assert len(outputs) == 0 + assert len(final_output.outputs[0].token_ids) > 0 + + assert not engine.output_processor.has_unfinished_requests() + + +async def collect_outputs( + engine: AsyncLLM, + request_id: str, + prompt: PromptType, + sampling_params: SamplingParams, + outputs_list: list[RequestOutput], +) -> Optional[RequestOutput]: + """Helper to collect outputs and return the final one.""" + final_output: Optional[RequestOutput] = None + async for output in engine.generate(request_id=request_id, + prompt=prompt, + sampling_params=sampling_params): + if not output.finished: + outputs_list.append(output) + final_output = output + return final_output diff --git a/tests/v1/test_internal_lb_dp.py b/tests/v1/test_internal_lb_dp.py index ca80d3a49..2b031865c 100644 --- a/tests/v1/test_internal_lb_dp.py +++ b/tests/v1/test_internal_lb_dp.py @@ -4,6 +4,8 @@ import asyncio import os import threading import time +import traceback +from typing import Optional, cast import openai # use the official client for correctness check import pytest @@ -41,12 +43,15 @@ class MultinodeInternalLBServerManager: self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] + self.servers: list[Optional[tuple[RemoteOpenAIServer, + list[str]]]] = [None] * (dp_size // + dp_per_node) self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: """Start all server instances for multi-node internal LB mode.""" - for rank in range(0, self.dp_size, self.dp_per_node): + for server_idx, rank in enumerate( + range(0, self.dp_size, self.dp_per_node)): # Create server args for this specific rank server_args = self.base_server_args.copy() @@ -87,7 +92,7 @@ class MultinodeInternalLBServerManager: ]) # Use a thread to start each server to allow parallel initialization - def start_server(r: int, sargs: list[str]): + def start_server(sidx: int, r: int, sargs: list[str]): gpus_per_node = self.tp_size * self.dp_per_node try: # Start the server @@ -110,13 +115,14 @@ class MultinodeInternalLBServerManager: f"{self.api_server_count} API servers") else: print(f"Headless node (rank {r}) started successfully") - self.servers.append((server, sargs)) + self.servers[sidx] = (server, sargs) except Exception as e: print(f"Failed to start server rank {r}: {e}") + traceback.print_exc() raise thread = threading.Thread(target=start_server, - args=(rank, server_args)) + args=(server_idx, rank, server_args)) thread.start() self.server_threads.append(thread) @@ -128,18 +134,20 @@ class MultinodeInternalLBServerManager: # Give servers additional time to fully initialize and coordinate time.sleep(3) - if len(self.servers) != self.dp_size // self.dp_per_node: + if not all(self.servers): raise Exception("Servers failed to start") - return self.servers + return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers) def __exit__(self, exc_type, exc_val, exc_tb): """Stop all server instances.""" while self.servers: - try: - self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) - except Exception as e: - print(f"Error stopping server: {e}") + if server := self.servers.pop(): + try: + server[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + traceback.print_exc() class APIOnlyServerManager: @@ -157,7 +165,8 @@ class APIOnlyServerManager: self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] + self.servers: list[Optional[tuple[RemoteOpenAIServer, + list[str]]]] = [None] * 2 self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: @@ -209,7 +218,7 @@ class APIOnlyServerManager: server.__enter__() print(f"API-only server started successfully with " f"{self.api_server_count} API servers") - self.servers.append((server, api_server_args)) + self.servers[0] = (server, api_server_args) except Exception as e: print(f"Failed to start API-only server: {e}") raise @@ -231,7 +240,7 @@ class APIOnlyServerManager: server.__enter__() print(f"Headless engines server started successfully with " f"{self.dp_size} engines") - self.servers.append((server, engines_server_args)) + self.servers[1] = (server, engines_server_args) except Exception as e: print(f"Failed to start headless engines server: {e}") raise @@ -253,18 +262,20 @@ class APIOnlyServerManager: # Give servers additional time to fully initialize and coordinate time.sleep(3) - if len(self.servers) != 2: + if not all(self.servers): raise Exception("Both servers failed to start") - return self.servers + return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers) def __exit__(self, exc_type, exc_val, exc_tb): """Stop both server instances.""" while self.servers: - try: - self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) - except Exception as e: - print(f"Error stopping server: {e}") + if server := self.servers.pop(): + try: + server[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + traceback.print_exc() @pytest.fixture(scope="module") @@ -560,7 +571,7 @@ async def test_api_only_multinode_dp_completion( assert len(results) == num_requests assert all(completion is not None for completion in results) - _, api_server_args = api_only_servers[0] + api_server, api_server_args = api_only_servers[0] api_server_count = ( api_server_args.count('--api-server-count') and api_server_args[api_server_args.index('--api-server-count') + 1] @@ -569,7 +580,6 @@ async def test_api_only_multinode_dp_completion( f"engines on headless server (API server count: {api_server_count})") # Check request balancing via Prometheus metrics - api_server = api_only_servers[0][0] check_request_balancing(api_server, DP_SIZE) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index e97cdf482..4bcc63f29 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -772,6 +772,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): head_dim=hf_config.mamba_d_head, rms_norm_eps=hf_config.rms_norm_eps, activation=hf_config.hidden_act, + cache_config=cache_config, + model_config=model_config, prefix=key, ) # suppress var not used error diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 70605d3c5..a318637c5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -319,38 +319,6 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, repetition_penalties) -def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, slot_mapping: torch.Tensor, - block_tables: torch.Tensor) -> None: - """Advance a step on GPU for existing inputs for a multi-step runner""" - return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, - block_size, input_tokens, - sampled_token_ids, - input_positions, seq_lens, - slot_mapping, block_tables) - - -def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, slot_mapping: torch.Tensor, - block_tables: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - block_table_bound: torch.Tensor) -> None: - - return torch.ops._C.advance_step_flashinfer( - num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, block_tables, - paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, - block_table_bound) - - # fused quant layer norm ops def rms_norm_dynamic_per_token_quant( input: torch.Tensor, @@ -452,6 +420,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): def _gptq_marlin_gemm_fake(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], @@ -1048,6 +1017,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], @@ -1062,7 +1032,7 @@ def gptq_marlin_gemm(a: torch.Tensor, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales, global_scale, b_zeros, g_idx, perm, workspace, b_q_type.id, size_m, size_n, size_k, is_k_full, @@ -1540,7 +1510,9 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], - b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qweight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], @@ -1556,11 +1528,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], use_fp32_reduce: bool, is_zp_float: bool) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, - perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, - topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep, - b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, - use_fp32_reduce, is_zp_float) + input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros, + g_idx, perm, workspace, sorted_token_ids, expert_ids, + num_tokens_past_padded, topk_weights, moe_block_size, top_k, + mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k, + is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2417fe06a..d21f07756 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -101,11 +101,6 @@ class AttentionBackend(ABC): ) -> None: raise NotImplementedError - def advance_step(self, model_input: "ModelRunnerInputBase", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int) -> None: - raise NotImplementedError - @classmethod def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index bd9bc4277..fac3c318a 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -35,8 +35,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @@ -326,79 +325,6 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata): cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class DifferentialFlashAttentionMetadataBuilder( AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ee36fd19e..e52480d5c 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -32,8 +32,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @@ -309,79 +308,6 @@ class FlashAttentionMetadata(AttentionMetadata): cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 78d8a67e3..208cacec3 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -51,8 +51,7 @@ from vllm.utils.flashinfer import use_trtllm_attention logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder class FlashInferBackend(AttentionBackend): @@ -428,7 +427,7 @@ class FlashInferMetadata(AttentionMetadata): query_start_loc: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None - # used for GPU in-place advance_step + # used for GPU operations seq_lens_tensor: Optional[torch.Tensor] = None block_table_bound: Optional[torch.Tensor] = None @@ -587,66 +586,6 @@ class FlashInferMetadata(AttentionMetadata): return None return self - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - # Flashinfer doesn't support speculative decoding + chunked-prefill - # + multi-step scheduling yet. - assert self.decode_query_len == 1 - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens_tensor is not None - - assert num_seqs > 0 - assert num_queries > 0 - assert model_input.attn_metadata is not None - assert sampled_token_ids is not None - - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() - - # Update GPU tensors - ops.advance_step_flashinfer( - num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=model_input.input_tokens, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables, - paged_kv_indices=self.paged_kv_indices, - paged_kv_indptr=self.paged_kv_indptr, - paged_kv_last_page_len=self.paged_kv_last_page_len, - block_table_bound=self.block_table_bound) - class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index a242ac9bb..f23c09695 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -18,9 +18,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - class FlashMLABackend(MLACommonBackend): @@ -62,16 +59,6 @@ class FlashMLAMetadata(MLACommonMetadata): self.decode_num_splits return decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - raise NotImplementedError( - "advance_step is not implemented for FlashMLA") - class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 52c4a9e7d..8ff7f5674 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -234,8 +234,7 @@ except ImportError: flash_attn_varlen_func = None if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder is_hip = current_platform.is_rocm() @@ -631,90 +630,6 @@ class MLACommonMetadata(AttentionMetadata): is_profile_run=self.is_profile_run) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - self._ops_advance_step(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions) - - def _ops_advance_step(self, num_seqs: int, num_queries: int, - block_size: int, input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor) -> None: - # here we use advance_step_flashinfo to update the paged_kv_* tensors - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): """ diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 820ddcab7..e630a6c6d 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -15,8 +15,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import (ModelInputForGPUBuilder) from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that @@ -201,65 +200,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - assert not turn_prefills_into_decodes, \ - ("Multi-Step + Chunked-Prefill is not supported for attention-free" - "models. turn_prefills_into_decodes is a " - "Multi-Step + Chunked-Prefill specific parameter.") - - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - # Update sequences, masking off entries greater than num_queries - device = self.seq_lens_tensor.device - mask = torch.arange(self.seq_lens_tensor.size(0), - device=device) < num_queries - self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype) - if sampled_token_ids is not None: - model_input.input_tokens.masked_scatter_( - mask, sampled_token_ids[:num_queries]) - class PlaceholderAttentionMetadataBuilder( AttentionMetadataBuilder[PlaceholderAttentionMetadata]): diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index a165a786d..a2e971043 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional, Type, Union import torch -import vllm._custom_ops as ops import vllm.envs as envs from vllm.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, @@ -107,26 +106,6 @@ class AiterMLAMetadata(MLACommonMetadata): return self._cached_decode_metadata - def _ops_advance_step(self, num_seqs: int, num_queries: int, - block_size: int, input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor) -> None: - - ops.advance_step_flashinfer( - num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables, - paged_kv_indices=self.paged_kv_indices, - paged_kv_indptr=self.paged_kv_indptr, - paged_kv_last_page_lens=self.paged_kv_last_page_lens, - block_table_bound=self.block_table_bound) - class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index da3d9ff32..63e467f5a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -4,7 +4,7 @@ import itertools from dataclasses import dataclass from functools import cache -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -23,9 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - logger = init_logger(__name__) _PARTITION_SIZE_ROCM = 256 @@ -261,69 +258,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - - assert not turn_prefills_into_decodes, \ - ("Chunked prefill is not supported with rocm_flash_attn yet." - "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " - "specific parameter.") - - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class ROCmFlashAttentionMetadataBuilder( CommonMetadataBuilder[ROCmFlashAttentionMetadata]): diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 673fb5866..059e7a3b2 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -15,7 +15,7 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher import vllm.envs as envs -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname @@ -277,9 +277,6 @@ def split_graph(graph: fx.GraphModule, return split_gm, outputs -# we share the global graph pool among all the backends -global_graph_pool = None - compilation_start_time = 0.0 @@ -339,14 +336,37 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None) + # Lazy import here to avoid circular import + from .cuda_graph import CUDAGraphOptions + from .cuda_piecewise_backend import PiecewiseBackend - piecewise_backend = resolve_obj_by_qualname( - current_platform.get_piecewise_backend_cls()) - self.module.__dict__[target] = piecewise_backend( - submod, self.vllm_config, self.graph_pool, index, + piecewise_backend = PiecewiseBackend( + submod, self.vllm_config, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_dynamic_shape, self.vllm_backend) + if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + # resolve the static graph wrapper class (e.g. CUDAGraphWrapper + # class) as platform dependent. + static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls()) + + # Always assign PIECEWISE runtime mode to the + # CUDAGraphWrapper for piecewise_backend, to distinguish + # it from the FULL cudagraph runtime mode, no matter it + # is wrapped on a full or piecewise fx graph. + self.module.__dict__[target] = static_graph_wrapper_class( + runnable=piecewise_backend, + vllm_config=self.vllm_config, + runtime_mode=CUDAGraphMode.PIECEWISE, + graph_pool=self.graph_pool, + cudagraph_options=CUDAGraphOptions( + debug_log_enable=piecewise_backend.is_first_graph, + gc_disable=not piecewise_backend.is_first_graph, + weak_ref_output=piecewise_backend.is_last_graph)) + else: + self.module.__dict__[target] = piecewise_backend + compilation_counter.num_piecewise_capturable_graphs_seen += 1 return output @@ -413,9 +433,7 @@ class VllmBackend: # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global global_graph_pool - if global_graph_pool is None: - global_graph_pool = current_platform.graph_pool_handle() + global_graph_pool = current_platform.get_global_graph_pool() # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. @@ -585,7 +603,7 @@ class VllmBackend: self._called = True - if not self.compilation_config.use_cudagraph or \ + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \ not self.compilation_config.cudagraph_copy_inputs: return self.split_gm diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py deleted file mode 100644 index 4d7aeeb4d..000000000 --- a/vllm/compilation/base_piecewise_backend.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Callable, Protocol - -import torch.fx as fx - -from vllm.compilation.backends import VllmBackend -from vllm.config import VllmConfig - - -class AbstractPiecewiseBackend(Protocol): - """ - PiecewiseBackend interface that allows platforms to extend - piecewise static graph. - """ - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend, **kwargs): - """ - Initializes the PiecewiseBackend class with compilation and - execution-related configurations. - - This class handles piecewise compilation, graph capturing, - and dispatching for specific input shapes. - - Args: - graph (fx.GraphModule): The graph represented in fx. - vllm_config (VllmConfig): Global configuration for vLLM. - graph_pool (Any): - Graph memory pool handle, e.g., - `torch.cuda.graph_pool_handle()`. - piecewise_compile_index (int): - Index of the current piecewise subgraph. - total_piecewise_compiles (int): - Total number of piecewise-compiled graphs. - sym_shape_indices (list[int]): - Indices of symbolic shape. - compiled_graph_for_general_shape (Callable): - Callable that executes the graph compiled for general shapes. - vllm_backend (VllmBackend): - Backend compiler that manages compilation and graph runtime - for vLLM. - - Keyword Args: - kwargs: Additional keyword arguments reserved for future - extensions or custom platforms. - """ - raise NotImplementedError - - def __call__(self, *args) -> Any: - """Executes the compiled graph for given input args. - - If this is the first invocation, executes the general compiled graph - and initiates the compilation process tracking. For subsequent calls, - dynamically dispatches execution to either a compiled graph or a static - graph based on the input shape. - - Args: - *args: Variable length input arguments to be passed into the - graph. The symbolic shape is expected to be in position - `sym_shape_indices[0]`. - - Returns: - Any: Output of the executed graph. This can be from the general - compiled graph, a specialized compiled version for the given shape, - or a replayed static graph. - """ - raise NotImplementedError diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py new file mode 100644 index 000000000..1c3f52c53 --- /dev/null +++ b/vllm/compilation/base_static_graph.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Protocol + +from vllm.config import CUDAGraphMode, VllmConfig + + +class AbstractStaticGraphWrapper(Protocol): + """ + StaticGraphWrapper interface that allows platforms to wrap a callable + to be captured as a static graph. + """ + + def __init__(self, runnable: Callable, vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs): + """ + Initializes the StaticGraphWrapper class with graph capturing and + execution-related configurations. + + Args: + runnable (Callable): The callable to be wrapped and captured. + vllm_config (VllmConfig): Global configuration for vLLM. + runtime_mode (CUDAGraphMode): The style of the static + graph runtime. See CUDAGraphMode in vllm/config.py. + Note that only the subset enum `NONE`, `PIECEWISE` and `FULL` + are used as concrete runtime mode for cudagraph dispatching. + graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + Keyword Args: + kwargs: Additional keyword arguments for platform-specific + configurations. + """ + raise NotImplementedError + + def __call__(self, *args, **kwargs) -> Any: + """ + Executes the wrapped callable. + + If the current runtime mode in the ForwardContext matches the runtime + mode of this instance, it replays the CUDAGraph or captures it using + the callable if it hasn't been captured yet. Otherwise, it calls the + original callable directly. + + Args: + *args: Variable length input arguments to be passed into the + callable. + **kwargs: Keyword arguments to be passed into the callable. + + Returns: + Any: Output of the executed callable. + """ + raise NotImplementedError diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py new file mode 100644 index 000000000..65a38197a --- /dev/null +++ b/vllm/compilation/cuda_graph.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch + +import vllm.envs as envs +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import validate_cudagraph_capturing_enabled +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class CUDAGraphEntry: + batch_descriptor: BatchDescriptor + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +@dataclasses.dataclass +class CUDAGraphOptions: + debug_log_enable: bool = True + gc_disable: bool = False + weak_ref_output: bool = True + + +class CUDAGraphWrapper: + """Wraps a runnable to add CUDA graph capturing and replaying ability. And + provide attribute access to the underlying `runnable` via `__getattr__`. + + The workflow of this wrapper in the cudagraph dispatching is as follows: + 1. At initialization, a runtime mode is assigned to the wrapper (FULL or + PIECEWISE). + 2. At runtime, the wrapper receives a runtime_mode and a + batch_descriptor(key) from the forward context and blindly trust them + for cudagraph dispatching. + 3. If runtime_mode is NONE or runtime_mode does not match the mode of the + wrapper, just call the runnable directly. + 4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, + the wrapper will perform cudagraph capture(if key does not exist, create + a new entry and cache it) or replay (if key exists in the cache). + + Note: CUDAGraphWrapper does not store persistent buffers or copy any + runtime inputs into that buffers for replay. We assume implementing them + is done outside of the wrapper. That is because we do not make any + assumption on the dynamic shape (batch size) of the runtime inputs, as a + trade-off for staying orthogonal to compilation logic. Nevertheless, + tracing and checking the input addresses to be consistent during replay is + guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". + """ + + def __init__(self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + graph_pool: Any = None, + cudagraph_options: Optional[CUDAGraphOptions] = None): + self.runnable = runnable + self.vllm_config = vllm_config + self.graph_pool = graph_pool + self.runtime_mode = runtime_mode + self.compilation_config = vllm_config.compilation_config + + self.first_run_finished = False + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't + # need to initialize a CUDAGraphWrapper. + assert self.runtime_mode != CUDAGraphMode.NONE + if self.graph_pool is None: + self.graph_pool = current_platform.get_global_graph_pool() + + if cudagraph_options is None: + cudagraph_options = CUDAGraphOptions() + self.cudagraph_options = cudagraph_options + # the entries for different batch descriptors that we need to capture + # cudagraphs for. + self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\ + = {} + + def __getattr__(self, key: str): + # allow accessing the attributes of the runnable. + if hasattr(self.runnable, key): + return getattr(self.runnable, key) + raise AttributeError(f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self.runnable}") + + def unwrap(self) -> Callable: + # in case we need to access the original runnable. + return self.runnable + + def __call__(self, *args, **kwargs): + forward_context = get_forward_context() + batch_descriptor = forward_context.batch_descriptor + cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode + + if cudagraph_runtime_mode == CUDAGraphMode.NONE or \ + cudagraph_runtime_mode != self.runtime_mode: + # CUDAGraphMode.NONE could mean the profile run, a warmup run, or + # running without cudagraphs. + # We do not trigger capture/replay if the runtime mode is not + # matches. This enables properly dispatching to the correct + # CUDAGraphWrapper when nesting multiple instances with different + # runtime modes. + return self.runnable(*args, **kwargs) + + if batch_descriptor not in self.concrete_cudagraph_entries: + # create a new entry for this batch descriptor + self.concrete_cudagraph_entries[batch_descriptor] = \ + CUDAGraphEntry(batch_descriptor=batch_descriptor) + + entry = self.concrete_cudagraph_entries[batch_descriptor] + + if entry.cudagraph is None: + if self.cudagraph_options.debug_log_enable: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every + # shape. E.g. we only log it for the first subgraph in + # piecewise mode. + logger.debug("Capturing a cudagraph on (%s,%s)", + self.runtime_mode.name, entry.batch_descriptor) + # validate that cudagraph capturing is legal at this point. + validate_cudagraph_capturing_enabled() + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if self.cudagraph_options.gc_disable: + # during every model forward for piecewise cudagraph + # mode, we will capture many pieces of cudagraphs + # (roughly one per layer). running gc again and again + # across layers will make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = self.runnable(*args, **kwargs) + if self.cudagraph_options.weak_ref_output: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph in piecewise cuadgraph mode, because + # the output of the last graph will not be used by + # any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + f"Input addresses for cudagraphs are different " + f"during replay. Expected {entry.input_addresses}, " + f"got {new_input_addresses}") + + entry.cudagraph.replay() + return entry.output diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index 8c49ea6cc..ae26e9f1b 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -2,21 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from contextlib import ExitStack -from typing import Any, Callable, Optional -from unittest.mock import patch +from typing import Any, Callable -import torch import torch.fx as fx import vllm.envs as envs from vllm.compilation.backends import VllmBackend -from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors logger = init_logger(__name__) @@ -24,44 +18,29 @@ logger = init_logger(__name__) @dataclasses.dataclass class ConcreteSizeEntry: runtime_shape: int - need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in cudagraph_capture_sizes - compiled: bool = False runnable: Callable = None # type: ignore - num_finished_warmup: int = 0 - cudagraph: Optional[torch.cuda.CUDAGraph] = None - output: Optional[Any] = None - - # for cudagraph debugging, track the input addresses - # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None -class CUDAPiecewiseBackend: +class PiecewiseBackend: def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], + piecewise_compile_index: int, total_piecewise_compiles: int, + sym_shape_indices: list[int], compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend): """ The backend for piecewise compilation. - It mainly handles the compilation and cudagraph capturing. + It mainly handles the compilation of static shapes and + dispatching based on runtime shape. We will compile `self.graph` once for the general shape, and then compile for different shapes specified in `compilation_config.compile_sizes`. - - Independently, we will capture cudagraph for different shapes. - - If a shape needs both compilation and cudagraph, we will - compile it first, and then capture cudagraph. """ self.graph = graph self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles self.vllm_backend = vllm_backend @@ -70,11 +49,10 @@ class CUDAPiecewiseBackend: self.is_last_graph = ( piecewise_compile_index == total_piecewise_compiles - 1) + self.is_full_graph = total_piecewise_compiles == 1 + self.compile_sizes: set[int] = set( self.compilation_config.compile_sizes) - self.cudagraph_capture_sizes: set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() self.first_run_finished = False @@ -84,18 +62,18 @@ class CUDAPiecewiseBackend: self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to either - # compile or capture cudagraph + # the entries for different shapes that we need to compile self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} # to_be_compiled_sizes tracks the remaining sizes to compile, # and updates during the compilation process, so we need to copy it self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + + # We only keep compilation management inside this class directly. + for shape in self.compile_sizes: self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, - need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.cudagraph_capture_sizes, + runnable=self.compiled_graph_for_general_shape, ) def check_for_ending_compilation(self): @@ -112,16 +90,14 @@ class CUDAPiecewiseBackend: return self.compiled_graph_for_general_shape(*args) runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) entry = self.concrete_size_entries[runtime_shape] - if entry.runnable is None: - entry.runnable = self.compiled_graph_for_general_shape - - if entry.need_to_compile and not entry.compiled: + if not entry.compiled: entry.compiled = True self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments @@ -138,81 +114,4 @@ class CUDAPiecewiseBackend: if self.is_last_graph and not self.to_be_compiled_sizes: self.check_for_ending_compilation() - # Skip CUDA graphs if this entry doesn't use them OR - # if we're supposed to skip them globally - skip_cuda_graphs = get_forward_context().skip_cuda_graphs - if not entry.use_cudagraph or skip_cuda_graphs: - return entry.runnable(*args) - - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if self.is_first_graph: - logger.debug( - "Warming up %s/%s for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - runtime_shape) - return entry.runnable(*args) - - if self.is_first_graph: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every shape. - # We only log it in the debug mode. - logger.debug("Capturing a cudagraph for shape %s", - runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack() as stack: - if not self.is_first_graph: - # during every model forward, we will capture - # many pieces of cudagraphs (roughly one per layer). - # running gc again and again across layers will - # make the cudagraph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) - - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if self.is_last_graph: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last graph - # will not be used by any other cuda graph. - output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_captured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during replay." - f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) - - entry.cudagraph.replay() - return entry.output + return entry.runnable(*args) diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 1e059b59f..9047bf3cb 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -37,3 +37,21 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig): if context_manager is not None: context_manager.__exit__(None, None, None) context_manager = None + + +cudagraph_capturing_enabled: bool = True + + +def validate_cudagraph_capturing_enabled(): + # used to monitor whether an cudagraph capturing is legal at runtime. + # should be called before any cudagraph capturing. + # if an illegal cudagraph capturing happens, raise an error. + global cudagraph_capturing_enabled + if not cudagraph_capturing_enabled: + raise RuntimeError("CUDA graph capturing detected at an inappropriate " + "time. This operation is currently disabled.") + + +def set_cudagraph_capturing_enabled(enabled: bool): + global cudagraph_capturing_enabled + cudagraph_capturing_enabled = enabled diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 8d5df1061..96d4eae2e 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -11,7 +11,8 @@ from typing import Callable, Optional import torch import vllm.envs as envs -from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.config import (CompilationLevel, CUDAGraphMode, + get_current_vllm_config) from vllm.logger import init_logger logger = init_logger(__name__) @@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher: except Exception: pass - if self.vllm_config.compilation_config.use_cudagraph and \ - "update" in new_code.co_names: + if self.vllm_config.compilation_config.cudagraph_mode != \ + CUDAGraphMode.NONE and "update" in new_code.co_names: import depyf src = depyf.decompile(new_code) msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index b4ea15ef5..280ae60c9 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -29,10 +29,10 @@ from typing_extensions import Self, assert_never, runtime_checkable import vllm.envs as envs from vllm import version -from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, +from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, - PassConfig) + CUDAGraphMode, PassConfig) from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.utils import ConfigType, config @@ -388,6 +388,10 @@ class ModelConfig: interleave_mm_strings: bool = False """Enable fully interleaved support for multimodal prompts, while using --chat-template-content-format=string. Defaults to False.""" + skip_mm_profiling: bool = False + """When enabled, skips multimodal memory profiling and only profiles with + language backbone model during engine initialization. + """ media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) """Additional args passed to process media inputs, keyed by modalities. For example, to set num_frames for video, set @@ -837,7 +841,8 @@ class ModelConfig: media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, - interleave_mm_strings=self.interleave_mm_strings) + interleave_mm_strings=self.interleave_mm_strings, + skip_mm_profiling=self.skip_mm_profiling) return None @@ -2511,6 +2516,16 @@ class MultiModalConfig: Enable fully interleaved support for multimodal prompts. """ + skip_mm_profiling: bool = False + """ + When enabled, skips multimodal memory profiling and only profiles with + language backbone model during engine initialization. + + This reduces engine startup time but shifts the responsibility to users for + estimating the peak memory usage of the activation of multimodal encoder and + embedding cache. + """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -3514,11 +3529,21 @@ class VllmConfig: else: self.compilation_config.level = \ CompilationLevel.NO_COMPILATION + else: # NB: Passing both --enforce-eager and a compilation level # in V0 means the compilation level wins out. self.compilation_config.level = CompilationLevel.NO_COMPILATION + # if cudagraph_mode is not explicitly set by users, set default value + if self.compilation_config.cudagraph_mode is None: + if envs.VLLM_USE_V1 and self.compilation_config.level \ + == CompilationLevel.PIECEWISE: + self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + # async tp is built on top of sequence parallelism # and requires it to be enabled. if self.compilation_config.pass_config.enable_async_tp: @@ -3526,12 +3551,13 @@ class VllmConfig: True if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") - if envs.VLLM_USE_V1 and self.model_config is not None and \ - not self.model_config.enforce_eager: - # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph - # is set to True, full CUDA graphs will be used. + + # disable cudagraph when enforce eager execution + if self.model_config is not None and self.model_config.enforce_eager: + logger.info("Cudagraph is disabled under eager mode") + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + elif envs.VLLM_USE_V1: self.compilation_config.cudagraph_num_of_warmups = 1 - self.compilation_config.set_splitting_ops_for_v1() self._set_cudagraph_sizes() @@ -3551,12 +3577,6 @@ class VllmConfig: "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - if self.compilation_config.full_cuda_graph and \ - not self.model_config.disable_cascade_attn: - logger.info("full_cuda_graph is not supported with " - "cascade attention. Disabling cascade attention.") - self.model_config.disable_cascade_attn = True - disable_chunked_prefill_reasons: list[str] = [] if self.model_config and self.model_config.pooler_config: @@ -3597,9 +3617,32 @@ class VllmConfig: "to True to enable.") current_platform.check_and_update_config(self) + # final check of cudagraph mode after platform-specific update + if envs.VLLM_USE_V1: + if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \ + and self.model_config is not None and \ + not self.model_config.disable_cascade_attn: + logger.info("CUDAGraphMode.FULL is not supported with " + "cascade attention currently. Disabling cascade" + "attention.") + self.model_config.disable_cascade_attn = True + + if self.compilation_config.cudagraph_mode\ + .requires_piecewise_compilation(): + assert self.compilation_config.level == \ + CompilationLevel.PIECEWISE, \ + "Compilation level should be CompilationLevel.PIECEWISE "\ + "when cudagraph_mode piecewise cudagraphs is used, "\ + f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + if not self.instance_id: self.instance_id = random_uuid()[:5] + # Do this after all the updates to compilation_config.level + if envs.VLLM_USE_V1 and \ + self.compilation_config.level == CompilationLevel.PIECEWISE: + self.compilation_config.set_splitting_ops_for_v1() + if (envs.VLLM_USE_V1 and not self.scheduler_config.disable_hybrid_kv_cache_manager): # logger should only print warning message for hybrid models. As we diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 69cb0d973..ae11dec3c 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -23,6 +23,7 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] @@ -93,6 +94,15 @@ class CacheConfig: """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" + mamba_cache_dtype: MambaDType = "auto" + """The data type to use for the Mamba cache (both the conv as well as the + ssm state). If set to 'auto', the data type will be inferred from the model + config.""" + mamba_ssm_cache_dtype: MambaDType = "auto" + """The data type to use for the Mamba cache (ssm state only, conv state will + still be controlled by mamba_cache_dtype). If set to 'auto', the data type + for the ssm state will be determined by mamba_cache_dtype.""" + # Will be set after profiling. num_gpu_blocks: Optional[int] = field(default=None, init=False) """The number of blocks to allocate for GPU memory.""" @@ -123,6 +133,8 @@ class CacheConfig: """ factors: list[Any] = [] factors.append(self.cache_dtype) + factors.append(self.mamba_cache_dtype) + factors.append(self.mamba_ssm_cache_dtype) # `cpu_offload_gb` does not use `torch.compile` yet. hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 8a78d811b..56a2183f8 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum import hashlib from collections import Counter from dataclasses import asdict, field -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union -from pydantic import TypeAdapter +from pydantic import TypeAdapter, field_validator from pydantic.dataclasses import dataclass import vllm.envs as envs @@ -31,6 +32,40 @@ class CompilationLevel: PIECEWISE = 3 +class CUDAGraphMode(enum.Enum): + """ Constants for the cudagraph mode in CompilationConfig. + Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also + treated as concrete runtime mode for cudagraph runtime dispatching. + """ + NONE = 0 + PIECEWISE = 1 + FULL = 2 + FULL_DECODE_ONLY = (FULL, NONE) + FULL_AND_PIECEWISE = (FULL, PIECEWISE) + + def decode_mode(self) -> 'CUDAGraphMode': + return CUDAGraphMode(self.value[0]) if \ + self.separate_routine() else self + + def mixed_mode(self) -> 'CUDAGraphMode': + return CUDAGraphMode(self.value[1]) if \ + self.separate_routine() else self + + def requires_piecewise_compilation(self) -> bool: + return (self.decode_mode() == CUDAGraphMode.PIECEWISE + or self.mixed_mode() == CUDAGraphMode.PIECEWISE) + + def max_cudagraph_mode(self) -> 'CUDAGraphMode': + return CUDAGraphMode(max( + self.value)) if self.separate_routine() else self + + def has_full_cudagraphs(self) -> bool: + return self.max_cudagraph_mode() == CUDAGraphMode.FULL + + def separate_routine(self) -> bool: + return isinstance(self.value, tuple) + + @config @dataclass class PassConfig: @@ -91,6 +126,7 @@ class CompilationConfig: - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] - CudaGraph capture: - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] + - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode] - [`cudagraph_capture_sizes`] [vllm.config.CompilationConfig.cudagraph_capture_sizes] - [`cudagraph_num_of_warmups`] @@ -157,7 +193,7 @@ class CompilationConfig: By default, all custom ops are enabled when running without Inductor and disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. Inductor generates (fused) Triton kernels for disabled custom ops.""" - splitting_ops: list[str] = field(default_factory=list) + splitting_ops: Optional[list[str]] = None """A list of ops to split the full graph into subgraphs, used in piecewise compilation.""" @@ -187,7 +223,43 @@ class CompilationConfig: constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" # CudaGraph compilation - use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1) + cudagraph_mode: Optional[CUDAGraphMode] = None + """ + The mode of the cudagraph. + - NONE, no cudagraph capture. + - PIECEWISE. (v1 default) + - FULL. + - FULL_DECODE_ONLY. + - FULL_AND_PIECEWISE. + + PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph + incompatiable ops (i.e. some attention ops) outside the cudagraph + for general flexibility. + This is the default mode. + + FULL mode: Capture full cudagraph for all batches. Can be good for small + models or workloads with small prompts; not supported by many backends. + Generally for performance FULL_AND_PIECEWISE is better. + + FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only. + Mixed prefill-decode batches are run without cudagraphs. Can be good for + decode instances in a P/D setup where prefill is not as important so we + can save some memory. + + FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and + piecewise cudagraph for prefill and mixed prefill-decode batches. + This is like the most performant mode for most models. + + Currently, the cudagraph mode is only used for the v1 engine. + Note that the cudagraph logic is generally orthogonal to the + compilation logic. While piecewise cudagraphs require piecewise + compilation (level=PIECEWISE and non-empty splitting_ops), full + cudagraphs are supported with and without compilation. + + Warning: This flag is new and subject to change in addition + more modes may be added. + """ + use_cudagraph: bool = True """Whether to use cudagraph inside compilation. - False: cudagraph inside compilation is not used. - True: cudagraph inside compilation is used. It requires @@ -197,8 +269,9 @@ class CompilationConfig: CompilationLevel.PIECEWISE (aka -O3). Note that this is orthogonal to the cudagraph capture logic outside of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future.""" + Warning: This flag is deprecated and will be removed in the next major or + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. + """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. @@ -213,12 +286,17 @@ class CompilationConfig: cudagraph. If the caller can guarantee that the same input buffers are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False.""" - full_cuda_graph: bool = False + internally managed buffer. Default is False. + Note that this flag is only effective when cudagraph_mode is PIECEWISE. + """ + full_cuda_graph: Optional[bool] = False """whether to use a full cuda graph for the entire forward pass rather than splitting certain operations such as attention into subgraphs. Thus this flag cannot be used together with splitting_ops. This may provide - performance benefits for smaller models.""" + performance benefits for smaller models. + Warning: This flag is deprecated and will be removed in the next major or + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. + """ pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" @@ -253,6 +331,13 @@ class CompilationConfig: Map from layer name to layer objects that need to be accessed outside model code, e.g., Attention, FusedMOE when dp_size>1.""" + # Attention ops; used for piecewise cudagraphs + _attention_ops: ClassVar[list[str]] = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + "vllm.mamba_mixer2", + ] + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -297,13 +382,26 @@ class CompilationConfig: if pass_config_exclude: exclude["pass_config"] = pass_config_exclude - return TypeAdapter(CompilationConfig).dump_json( - self, - exclude=exclude, # type: ignore[arg-type] - exclude_unset=True).decode() + # The cast to string is necessary because Pydantic is mocked in docs + # builds and sphinx-argparse doesn't know the return type of decode() + return str( + TypeAdapter(CompilationConfig).dump_json( + self, + exclude=exclude, # type: ignore[arg-type] + exclude_unset=True).decode()) __str__ = __repr__ + @field_validator("cudagraph_mode", mode="before") + @classmethod + def validate_cudagraph_mode_before(cls, value: Any) -> Any: + """ + enable parse the `cudagraph_mode` enum type from string + """ + if isinstance(value, str): + return CUDAGraphMode[value.upper()] + return value + def __post_init__(self) -> None: count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") @@ -341,7 +439,26 @@ class CompilationConfig: if isinstance(self.pass_config, dict): self.pass_config = PassConfig(**self.pass_config) - def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]: + # migrate the deprecated flags + if not self.use_cudagraph: + logger.warning("use_cudagraph is deprecated, use " + "cudagraph_mode=NONE instead.") + if self.cudagraph_mode is not None: + raise ValueError( + "use_cudagraph and cudagraph_mode are mutually" + " exclusive, prefer cudagraph_mode since " + "use_cudagraph is deprecated.") + self.cudagraph_mode = CUDAGraphMode.NONE + if self.full_cuda_graph: + logger.warning("full_cuda_graph is deprecated, use " + "cudagraph_mode=FULL instead.") + if self.cudagraph_mode is not None: + raise ValueError("full_cuda_graph and cudagraph_mode are " + "mutually exclusive, prefer cudagraph_mode " + "since full_cuda_graph is deprecated.") + self.cudagraph_mode = CUDAGraphMode.FULL + + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") @@ -414,15 +531,34 @@ class CompilationConfig: self.max_capture_size] = self.max_capture_size def set_splitting_ops_for_v1(self): - # NOTE: this function needs to be called - if self.splitting_ops and self.full_cuda_graph: - raise ValueError("full_cuda_graph cannot be used together with " - "splitting_ops, as Full CUDA graph will override " - f"the splitting_ops: {self.splitting_ops}") + # NOTE: this function needs to be called only when level is + # CompilationLevel.PIECEWISE + assert self.level == CompilationLevel.PIECEWISE, ( + "set_splitting_ops_for_v1 should only be called when " + "level is CompilationLevel.PIECEWISE") - if not self.splitting_ops: - self.splitting_ops = [] if self.full_cuda_graph else [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - "vllm.mamba_mixer2", - ] + if self.splitting_ops is None: + # NOTE: When using full cudagraph, instead of setting an empty + # list and capture the full cudagraph inside the flattened fx + # graph, we keep the piecewise fx graph structure but capture the + # full cudagraph outside the fx graph. This reduces some cpu + # overhead when the runtime batch_size is not cudagraph captured. + # see https://github.com/vllm-project/vllm/pull/20059 for details. + self.splitting_ops = self._attention_ops + elif len(self.splitting_ops) == 0: + logger.warning_once("Using piecewise compilation with empty " + "splitting_ops.") + if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: + logger.warning_once( + "When compilation level is piecewise with empty " + "splitting_ops, PIECEWISE cudagraph_mode will be " + "treated as FULL cudagraph_mode. Please ensure you are " + "using attention backends that support cudagraph or set " + "cudagraph_mode to NONE explicitly if encountering " + "any problems.") + self.cudagraph_mode = CUDAGraphMode.FULL + self.splitting_ops = [] + + def splitting_ops_contain_attention(self) -> bool: + return self.splitting_ops is not None and all( + op in self.splitting_ops for op in self._attention_ops) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index b72104397..07fcdecac 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -325,4 +325,8 @@ class KVConnectorBase_V1(ABC): str: the required KV cache layout. e.g. HND, or NHD. None if the connector does not require a specific layout. """ + + if cls is KVConnectorBase_V1: + raise TypeError("get_required_kvcache_layout should not be called " + "on the abstract base class") return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 7d67c76e2..d3f6a226d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -228,9 +228,10 @@ class MultiConnector(KVConnectorBase_V1): for ktc in ktcs: kv_transfer_config = KVTransferConfig(**ktc) temp_vllm_config.kv_transfer_config = kv_transfer_config + connector_cls = KVConnectorFactory.get_connector_class( + kv_transfer_config) required_kvcache_layout = ( - KVConnectorBase_V1.get_required_kvcache_layout( - temp_vllm_config)) + connector_cls.get_required_kvcache_layout(temp_vllm_config)) if required_kvcache_layout is not None: layouts.add(required_kvcache_layout) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dd1072da0..f8af6d36e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, DeviceConfig, DistributedExecutorBackend, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, ModelConfig, ModelDType, ModelImpl, - MultiModalConfig, ObservabilityConfig, ParallelConfig, - PoolerConfig, PrefixCachingHashAlgo, RunnerOption, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - TaskOption, TokenizerMode, VllmConfig, get_attr_docs, - get_field) + LoRAConfig, MambaDType, ModelConfig, ModelDType, + ModelImpl, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, + RunnerOption, SchedulerConfig, SchedulerPolicy, + SpeculativeConfig, TaskOption, TokenizerMode, + VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins @@ -350,6 +350,7 @@ class EngineArgs: MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb + skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled @@ -421,6 +422,8 @@ class EngineArgs: override_attention_dtype: str = ModelConfig.override_attention_dtype calculate_kv_scales: bool = CacheConfig.calculate_kv_scales + mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype + mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype additional_config: dict[str, Any] = \ get_field(VllmConfig, "additional_config") @@ -693,6 +696,10 @@ class EngineArgs: **cache_kwargs["calculate_kv_scales"]) cache_group.add_argument("--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]) + cache_group.add_argument("--mamba-cache-dtype", + **cache_kwargs["mamba_cache_dtype"]) + cache_group.add_argument("--mamba-ssm-cache-dtype", + **cache_kwargs["mamba_ssm_cache_dtype"]) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -716,6 +723,8 @@ class EngineArgs: multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) + multimodal_group.add_argument("--skip-mm-profiling", + **multimodal_kwargs["skip_mm_profiling"]) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -918,6 +927,7 @@ class EngineArgs: limit_mm_per_prompt=self.limit_mm_per_prompt, interleave_mm_strings=self.interleave_mm_strings, media_io_kwargs=self.media_io_kwargs, + skip_mm_profiling=self.skip_mm_profiling, use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, @@ -1101,6 +1111,8 @@ class EngineArgs: cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, kv_sharing_fast_prefill=self.kv_sharing_fast_prefill, + mamba_cache_dtype=self.mamba_cache_dtype, + mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, ) ray_runtime_env = None diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e5d31c1fd..af86835a4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -126,7 +126,7 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: - await asyncio.sleep(10.) + await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL) await engine_client.do_log_stats() task = asyncio.create_task(_force_log()) diff --git a/vllm/envs.py b/vllm/envs.py index 145ec3495..82084d1fc 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None + VLLM_LOG_STATS_INTERVAL: float = 10. VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None @@ -122,6 +123,7 @@ if TYPE_CHECKING: VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False + VLLM_MXFP4_USE_MARLIN: Optional[bool] = None VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 @@ -182,6 +184,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: return int(value) +def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: + if value is None: + return None + return bool(int(value)) + + def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. @@ -429,6 +437,12 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0")) if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None, + # If set, vllm will log stats at this interval in seconds + # If not set, vllm will log stats every 10 seconds. + "VLLM_LOG_STATS_INTERVAL": + lambda: val if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) + > 0. else 10., + # Trace function calls # If set to 1, vllm will trace function calls # Useful for debugging @@ -906,6 +920,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MARLIN_USE_ATOMIC_ADD": lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", + # Whether to use marlin kernel in mxfp4 quantization method + "VLLM_MXFP4_USE_MARLIN": + lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)), + # Whether to turn on the outlines cache for V0 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. @@ -1090,6 +1108,12 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_TRTLLM_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), + # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. + # Otherwise, uses the first available of: flashinfer cutlass GEMM, + # vllm cutlass GEMM, marlin GEMM. + "VLLM_USE_TRTLLM_FP4_GEMM": + lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))), + # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. # If set to 1, allows GC to run during capture. @@ -1197,6 +1221,7 @@ def compute_hash() -> str: "VLLM_DP_SIZE", "VLLM_USE_STANDALONE_COMPILE", "VLLM_FUSED_MOE_CHUNK_SIZE", + "VLLM_USE_TRTLLM_FP4_GEMM", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 4686ba24e..c57c51d28 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -5,13 +5,13 @@ import time from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import torch import torch.distributed as dist import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger if TYPE_CHECKING: @@ -26,6 +26,27 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL batchsize_forward_time: defaultdict = defaultdict(list) +class BatchDescriptor(NamedTuple): + """ + Batch descriptor for cudagraph dispatching. We should keep the num of + items as minimal as possible to properly and uniquely describe the padded + batch for cudagraph. + """ + num_tokens: int + uniform_decode: bool = False + """ + False can also be used for an uniform decode batch to dispatch to the + cudagraph supporting non-uniform batches. + """ + + @property + def non_uniform(self) -> "BatchDescriptor": + """ + Return a non-uniform version of current batch descriptor. + """ + return BatchDescriptor(self.num_tokens, uniform_decode=False) + + def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], max_num_tokens: int, chunk_idx: int) -> list[int]: @@ -152,7 +173,15 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None - skip_cuda_graphs: bool = False + # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. + # by default NONE, no cudagraph is used. + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE + batch_descriptor: Optional[BatchDescriptor] = None + + def __post_init__(self): + assert self.cudagraph_runtime_mode in [ + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ + f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" _forward_context: Optional[ForwardContext] = None @@ -168,13 +197,13 @@ def get_forward_context() -> ForwardContext: @contextmanager def set_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - skip_cuda_graphs: bool = False, -): + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -198,7 +227,8 @@ def set_forward_context( virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata, - skip_cuda_graphs=skip_cuda_graphs, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, ) try: diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 1988c73ba..a49d41c18 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -18,6 +18,8 @@ from vllm.utils import direct_register_custom_op def fused_marlin_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + bias1: Optional[torch.Tensor], + bias2: Optional[torch.Tensor], w1_scale: torch.Tensor, w2_scale: torch.Tensor, gating_output: torch.Tensor, @@ -26,6 +28,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type_id: int, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, + activation: Optional[str] = "silu", expert_map: Optional[torch.Tensor] = None, global_scale1: Optional[torch.Tensor] = None, global_scale2: Optional[torch.Tensor] = None, @@ -88,6 +91,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] + assert topk_weights.dtype == torch.float32 M, K = hidden_states.shape E = w1.shape[0] @@ -138,6 +142,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, hidden_states, intermediate_cache1, w1, + bias1, w1_scale, global_scale1, w1_zeros, @@ -161,8 +166,28 @@ def fused_marlin_moe(hidden_states: torch.Tensor, use_fp32_reduce=True, is_zp_float=False) - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, 2 * N)) + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, 2 * N)) + elif activation == "swiglu_oai": + # NOTE: in gpt-oss, the gate_proj and up_proj is interleaved + # - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2] + # - origin: gate, up = gate_up[..., :N], gate_up[..., N:] + + @torch.compile(dynamic=True) + def swiglu_oai(gate_up): + alpha = 1.702 + limit = 7.0 + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + return (up + 1) * glu + + intermediate_cache2 = swiglu_oai(intermediate_cache1) + else: + raise ValueError(f"Unsupported activation: {activation}. " + "Only silu and swiglu_oai activations are supported.") if expert_map is not None: intermediate_cache3.zero_() @@ -171,6 +196,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache2, intermediate_cache3, w2, + bias2, w2_scale, global_scale2, w2_zeros, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 98087a35e..1c497fa55 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1189,10 +1189,10 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( hidden_states: torch.Tensor, input_scale: torch.Tensor, gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - activation_scale: torch.Tensor, gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, num_experts: int, top_k: int, num_expert_group: Optional[int], @@ -1206,17 +1206,12 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( num_expert_group = num_expert_group if num_expert_group is not None else 0 topk_group = topk_group if topk_group is not None else 0 - quant_hidden_states, input_scale = moe_kernel_quantize_input( + quant_hidden_states, _ = moe_kernel_quantize_input( hidden_states, input_scale, quant_dtype=torch.float8_e4m3fn, per_act_token_quant=False) - output1_scales_scalar = gemm1_weights_scale * input_scale * ( - 1.0 / activation_scale) - output1_scales_gate_scalar = gemm1_weights_scale * input_scale - output2_scales_scalar = activation_scale * gemm2_weights_scale - from vllm.utils.flashinfer import ( flashinfer_trtllm_fp8_per_tensor_scale_moe) return flashinfer_trtllm_fp8_per_tensor_scale_moe( @@ -1244,24 +1239,24 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( def flashinfer_fused_moe_per_tensor_scale_fp8_fake( routing_logits: torch.Tensor, - routing_bias: torch.Tensor, + routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, + input_scale: torch.Tensor, gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, output1_scales_scalar: torch.Tensor, output1_scales_gate_scalar: torch.Tensor, - gemm2_weights: torch.Tensor, output2_scales_scalar: torch.Tensor, num_experts: int, top_k: int, - num_expert_group: int, - topk_group: int, + num_expert_group: Optional[int], + topk_group: Optional[int], intermediate_size: int, local_expert_offset: int, local_num_experts: int, - routed_scaling_factor: float = 1.0, - use_routing_scales_on_input: bool = False, - tile_tokens_dim: int = 8, - routing_method_type: int = 0) -> torch.Tensor: + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0) -> torch.Tensor: pass diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ddc02168e..36e758258 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -36,7 +36,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, - has_triton_kernels, is_torch_equal_or_newer, round_up) + round_up) from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): @@ -751,19 +751,11 @@ class FusedMoE(CustomOp): self.global_num_experts = num_experts + num_redundant_experts # we padding globally so EP buffer allocation works - if quant_config and quant_config.get_name() == "mxfp4": - if not current_platform.is_device_capability(100): - if not is_torch_equal_or_newer("2.8.0"): - raise RuntimeError( - "Mxfp4 on non-blackwell requires torch >= 2.8.0") - if not has_triton_kernels(): - raise NotImplementedError( - "triton_kernels must be installed for " - "mxfp4 on non-blackwell") - if (current_platform.is_rocm() - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): - hidden_size = round_up(hidden_size, 256) + if (quant_config and quant_config.get_name() == "mxfp4" + and (current_platform.is_rocm() + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)): + hidden_size = round_up(hidden_size, 256) # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index 0a836fd17..3256ac034 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -11,7 +11,7 @@ from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionMetadata) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.platforms import current_platform -from vllm.v1.attention.backends.mamba_attn import ( +from vllm.v1.attention.backends.mamba2_attn import ( Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 17b7f84a9..3c7322260 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import NamedTuple, Optional import torch from torch import nn from torch.nn.parameter import Parameter from vllm import envs -from vllm.config import get_current_vllm_config +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import ForwardContext, get_forward_context @@ -19,7 +20,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -55,6 +56,8 @@ class MambaMixer(MambaBase, CustomOp): rms_norm_eps: float = 1e-5, activation="silu", is_lora_enabled: bool = False, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, prefix: str = ""): super().__init__() self.time_step_rank = time_step_rank @@ -152,15 +155,42 @@ class MambaMixer(MambaBase, CustomOp): # The inner tuple is (conv_state, ssm_state) self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + self.model_config = model_config + self.cache_config = cache_config self.prefix = prefix + def _ssm_transform( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.is_lora_enabled: + # Lora kernel requires contiguous tensor. + ssm_params = self.x_proj(x.contiguous())[0] + else: + ssm_params = self.x_proj(x)[0] + time_step, B, C = torch.split( + ssm_params, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1) + if self.use_rms_norm: + assert self.dt_layernorm is not None + assert self.b_layernorm is not None + assert self.c_layernorm is not None + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + return discrete_time_step, B, C + def forward(self, hidden_states: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): if not envs.VLLM_USE_V1: return CustomOp.forward(self, hidden_states, mamba_cache_params) else: - return self.forward_cuda(hidden_states, mamba_cache_params) + return self.forward_cuda( + hidden_states, + mamba_cache_params, + ) def forward_native(self, hidden_states: torch.Tensor, @@ -170,6 +200,27 @@ class MambaMixer(MambaBase, CustomOp): def forward_cuda(self, hidden_states: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): + """ + Run the Mamba-1 SSM pipeline. + + Steps + ----- + 1. Apply the gated-MLP linear projection to the raw input. + 2. Pass the projected sequence through the convolutional mixing layer. + 3. Feed the result into the State-Space Model (SSM) blocks. + 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) + to produce contextual representations. + 5. Project the contextualised sequence back + to the output embedding dimension. + + Batch handling + -------------- + Prefill and decode tokens are processed by dedicated CUDA + kernels for both the convolutional (conv1d) and SSM stages. + In the case of a mixed batch (containing both prefill and + decode tokens), both sets of kernels are executed independently + and their outputs are concatenated before the final output projection. + """ forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -185,126 +236,151 @@ class MambaMixer(MambaBase, CustomOp): self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - has_initial_state = mamba1_metadata.has_initial_states - context_lens_tensor = mamba1_metadata.context_lens_tensor + has_initial_states = mamba1_metadata.has_initial_states else: + assert isinstance(attn_metadata, AttentionMetadata) assert mamba_cache_params is not None conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state state_indices_tensor = mamba_cache_params.state_indices_tensor query_start_loc = attn_metadata.query_start_loc context_lens_tensor = attn_metadata.context_lens_tensor - + has_initial_states = None if context_lens_tensor is not None: - has_initial_state = context_lens_tensor > 0 + has_initial_states = context_lens_tensor > 0 # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) + hidden_states_BC, gate = projected_states.chunk(2, dim=-2) - # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if envs.VLLM_USE_V1 and attn_metadata is None: # V1 profile run - hidden_states = hidden_states.contiguous() - return self.out_proj(hidden_states.transpose(-2, -1))[0] + hidden_states_BC = hidden_states_BC.contiguous() + return self.out_proj(hidden_states_BC.transpose(-2, -1))[0] - if query_start_loc is not None and context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + has_prefill = num_prefill_tokens > 0 + has_decode = num_decode_tokens > 0 + + prefill_decode_split = split_batch_to_prefill_and_decode( + hidden_states_BC, + gate, + state_indices_tensor, + query_start_loc, + has_initial_states, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + num_decodes, + ) + hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p + hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d + gate_p = prefill_decode_split.gate_p + gate_d = prefill_decode_split.gate_d + state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p + state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d + query_start_loc_p = prefill_decode_split.query_start_loc_p + has_initial_states_p = prefill_decode_split.has_initial_states_p + + ssm_outputs = [] + + if has_prefill: + # 2. Convolution sequence transformation + conv_out_p = causal_conv1d_fn( + hidden_states_BC_p, conv_weights, - bias=self.conv1d.bias, + self.conv1d.bias, activation=self.activation, conv_states=conv_state, - has_initial_state=has_initial_state, - cache_indices=state_indices_tensor, - query_start_loc=query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p) + # 3. State Space Model sequence transformations. + discrete_time_step_p, B_p, C_p = self._ssm_transform( + conv_out_p.transpose(-2, -1)) + time_proj_bias = self._time_proj_bias() + + # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) + scan_out_p = selective_scan_fn( + conv_out_p, + ssm_state, + discrete_time_step_p, + self.A, + B_p.transpose(-2, -1), + C_p.transpose(-2, -1), + self.D.float(), + gate_p, + time_proj_bias, + delta_softplus=True, + cache_indices=state_indices_tensor_p, + has_initial_state=has_initial_states_p, + query_start_loc=query_start_loc_p) + ssm_outputs.append(scan_out_p) + + if has_decode: + # 2. Convolution sequence transformation + conv_out_d = causal_conv1d_update( + hidden_states_BC_d.transpose(0, 1), conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) + conv_state_indices=state_indices_tensor_d).transpose(0, 1) - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C + # 3. State Space Model sequence transformation. + discrete_time_step_d, B_d, C_d = self._ssm_transform( + conv_out_d.transpose(-2, -1)) + time_proj_bias = self._time_proj_bias() - if self.is_lora_enabled: - # lora kernel requires contiguous tensor - ssm_parameters = self.x_proj( - hidden_states.transpose(-2, -1).contiguous())[0] - else: - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) - if self.use_rms_norm: - assert self.dt_layernorm is not None - assert self.b_layernorm is not None - assert self.c_layernorm is not None - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - - if query_start_loc is not None and context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - cache_indices=state_indices_tensor, - has_initial_state=has_initial_state, - query_start_loc=query_start_loc) - else: - scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) + # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) + scan_outputs_d = torch.empty_like( + hidden_states_BC_d.transpose(0, 1)) selective_state_update(ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), + conv_out_d.transpose(0, 1), + discrete_time_step_d.transpose(0, 1), self.A, - B, - C, + B_d, + C_d, self.D, - gate.transpose(0, 1), + gate_d.transpose(0, 1), time_proj_bias, dt_softplus=True, - state_batch_indices=state_indices_tensor, - out=scan_outputs) - scan_outputs = scan_outputs.transpose(0, 1) + state_batch_indices=state_indices_tensor_d, + out=scan_outputs_d) + scan_outputs_d = scan_outputs_d.transpose(0, 1) - # 4. Final linear projection - if self.is_lora_enabled: - # lora kernel requires contiguous tensor - contextualized_states = self.out_proj( - scan_outputs.transpose(-2, -1).contiguous())[0] + if envs.VLLM_USE_V1: + ssm_outputs.insert(0, scan_outputs_d) + else: + ssm_outputs.append(scan_outputs_d) + + scan_outputs_combined = ssm_outputs[0] if len( + ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) + + # 5. Final output projection + if self.is_lora_enabled: # Lora kernel requires contiguous tensor. + scan_outputs_combined = scan_outputs_combined.transpose( + -2, -1).contiguous() + out = self.out_proj(scan_outputs_combined)[0] else: - contextualized_states = self.out_proj( - scan_outputs.transpose(-2, -1))[0] - return contextualized_states + out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] + + return out + + def get_state_dtype(self) -> tuple[torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba1_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.mamba1_state_shape( @@ -317,3 +393,69 @@ class MambaMixer(MambaBase, CustomOp): @property def mamba_type(self) -> str: return "mamba1" + + def _time_proj_bias(self) -> Optional[torch.Tensor]: + if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: + return self.dt_proj.bias.float() + return None + + +class PrefillDecodeSplit(NamedTuple): + hidden_states_BC_p: torch.Tensor + hidden_states_BC_d: torch.Tensor + gate_p: torch.Tensor + gate_d: torch.Tensor + state_indices_tensor_p: torch.Tensor + state_indices_tensor_d: torch.Tensor + query_start_loc_p: Optional[torch.Tensor] + has_initial_states_p: Optional[torch.Tensor] + + +def split_batch_to_prefill_and_decode( + hidden_states_BC: torch.Tensor, + gate: torch.Tensor, + state_indices_tensor: torch.Tensor, + query_start_loc: torch.Tensor, + has_initial_states: Optional[torch.Tensor], + num_prefill_tokens: int, + num_decode_tokens: int, + num_prefills: int, + num_decodes: int, +) -> PrefillDecodeSplit: + if envs.VLLM_USE_V1: + # In v1, decode tokens come first, then prefill tokens. + hidden_states_BC_d, hidden_states_BC_p = torch.split( + hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1) + gate_d, gate_p = torch.split(gate, + [num_decode_tokens, num_prefill_tokens], + dim=-1) + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, [num_decodes, num_prefills], dim=0) + query_start_loc_p = (query_start_loc[-num_prefills - 1:] - + num_decodes if num_prefills > 0 else None) + has_initial_states_p = has_initial_states[-num_prefills:] if ( + has_initial_states is not None and num_prefills > 0) else None + else: + # In v0, prefill tokens come first, then decode tokens. + hidden_states_BC_p, hidden_states_BC_d = torch.split( + hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1) + gate_p, gate_d = torch.split(gate, + [num_prefill_tokens, num_decode_tokens], + dim=-1) + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, [num_prefills, num_decodes], dim=0) + query_start_loc_p = (query_start_loc[:num_prefills + + 1] if num_prefills > 0 else None) + has_initial_states_p = has_initial_states[:num_prefills] if ( + has_initial_states is not None and num_prefills > 0) else None + + return PrefillDecodeSplit( + hidden_states_BC_p=hidden_states_BC_p, + hidden_states_BC_d=hidden_states_BC_d, + gate_p=gate_p, + gate_d=gate_d, + state_indices_tensor_p=state_indices_tensor_p, + state_indices_tensor_d=state_indices_tensor_d, + query_start_loc_p=query_start_loc_p, + has_initial_states_p=has_initial_states_p, + ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 10a5618c2..743e520ec 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -8,7 +8,7 @@ from torch import nn from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import get_current_vllm_config +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, @@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated @@ -36,7 +36,7 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op -from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 @@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp): **selective** state spaces) """ - def __init__( - self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation: str = "silu", - use_rms_norm: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() # For TP, the sharding plan is as follows: @@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp): # The inner tuple is (conv_state, ssm_state) self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + self.model_config = model_config + self.cache_config = cache_config self.prefix = prefix def forward_native( @@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp): dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), - ) + state_dtype=ssm_state.dtype) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor @@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp): # 5. Final linear projection output[:num_actual_tokens], _ = self.out_proj(hidden_states) + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba2_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=self.intermediate_size, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index ad1401791..66674d1a6 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,6 +1,58 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Union + +import torch + +from vllm.config import MambaDType, ModelDType from vllm.distributed import divide +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype + + +class MambaStateDtypeCalculator: + + @classmethod + def linear_attention_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + # TODO (tdoublep) requires testing + if mamba_cache_dtype == "float32": + raise ValueError("fp32 state for minimax is not yet supported") + state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (state_dtype, ) + + @classmethod + def mamba1_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + # TODO (tdoublep) requires kernel changes + if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32": + raise ValueError("fp32 state for mamba1 is not yet supported") + else: + return MambaStateDtypeCalculator.mamba2_state_dtype( + model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype) + + @classmethod + def mamba2_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, + model_dtype) + if mamba_ssm_cache_dtype == "auto": + temporal_state_dtype = conv_state_dtype + else: + temporal_state_dtype = ( + STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]) + + return (conv_state_dtype, temporal_state_dtype) class MambaStateShapeCalculator: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index fd74cb837..d0b3e9e52 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -41,6 +41,7 @@ def _mamba_chunk_scan_combined_fwd(x, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + state_dtype=None, out=None): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape @@ -118,7 +119,7 @@ def _mamba_chunk_scan_combined_fwd(x, if initial_states is not None else None, seq_idx=seq_idx, chunk_size=chunk_size, - out_dtype=C.dtype, + out_dtype=state_dtype if state_dtype is not None else C.dtype, is_cont_batched=cu_seqlens is not None) states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]) @@ -189,7 +190,8 @@ def mamba_chunk_scan_combined(x, dt_limit=(0.0, float("inf")), out=None, return_final_states=False, - return_varlen_states=False): + return_varlen_states=False, + state_dtype=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -206,6 +208,7 @@ def mamba_chunk_scan_combined(x, cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt out: Preallocated output tensor + state_dtype: The data type of the ssm state """ if not return_varlen_states: @@ -229,7 +232,8 @@ def mamba_chunk_scan_combined(x, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit, - out=out) + out=out, + state_dtype=state_dtype) if not return_varlen_states: if not return_final_states: return diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 6cf02658a..ed7ffb21e 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, check_marlin_supports_layer, check_moe_marlin_supports_layer, marlin_make_empty_g_idx, marlin_make_workspace_new, - marlin_moe_permute_scales, marlin_permute_scales, + marlin_moe_permute_scales, marlin_permute_bias, marlin_permute_scales, moe_awq_to_marlin_zero_points, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -303,6 +303,9 @@ class AWQMarlinLinearMethod(LinearMethodBase): layer.g_idx = marlin_make_empty_g_idx(device) layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply( self, layer: torch.nn.Module, @@ -469,6 +472,12 @@ class AWQMoEMethod(FusedMoEMethodBase): num_bits=self.quant_config.weight_bits) replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + if hasattr(layer, "w13_bias") and layer.w13_bias is not None: + layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) + + if hasattr(layer, "w2_bias") and layer.w2_bias is not None: + layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def apply( self, layer: torch.nn.Module, @@ -513,6 +522,8 @@ class AWQMoEMethod(FusedMoEMethodBase): x, layer.w13_qweight, layer.w2_qweight, + getattr(layer, "w13_bias", None), + getattr(layer, "w2_bias", None), layer.w13_scales, layer.w2_scales, router_logits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 69bced7c0..637a84372 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -192,7 +192,15 @@ class CompressedTensorsConfig(QuantizationConfig): quant_config.get("weights")) target_scheme_map[target]["input_activations"] = None - if is_activation_quantization_format(quant_format): + target_scheme_map[target]["format"] = quant_config.get( + "format") + format = target_scheme_map[target].get("format") + # If no per-config format defined, use global format in config + act_quant_format = is_activation_quantization_format( + format + ) if format is not None else is_activation_quantization_format( + quant_format) + if act_quant_format: input_activations = quant_config.get("input_activations") # The only case where we have activation quant supported # but no input_activations provided in the config @@ -389,8 +397,10 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_channel_group and input_quant_none and is_static) def _get_scheme_from_parts( - self, weight_quant: BaseModel, - input_quant: BaseModel) -> "CompressedTensorsScheme": + self, + weight_quant: BaseModel, + input_quant: BaseModel, + format: Optional[str] = None) -> "CompressedTensorsScheme": # Detect If Mixed Precision if self._is_fp4a16_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A16Fp4() @@ -412,7 +422,11 @@ class CompressedTensorsConfig(QuantizationConfig): group_size=weight_quant.group_size, actorder=weight_quant.actorder) - if is_activation_quantization_format(self.quant_format): + act_quant_format = is_activation_quantization_format( + format + ) if format is not None else is_activation_quantization_format( + self.quant_format) + if act_quant_format: if self._is_fp4a4_nvfp4(weight_quant, input_quant): if cutlass_fp4_supported( ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: @@ -507,6 +521,7 @@ class CompressedTensorsConfig(QuantizationConfig): scheme_dict = self.target_scheme_map[matched_target] weight_quant = scheme_dict.get("weights") input_quant = scheme_dict.get("input_activations") + format = scheme_dict.get("format") # Find the sparsity scheme of the layer # assume that fused layers inerhit first component's sparsity scheme @@ -547,7 +562,7 @@ class CompressedTensorsConfig(QuantizationConfig): scheme = self._get_scheme_from_parts( # type: ignore weight_quant=weight_quant, input_quant=input_quant, - ) + format=format) # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c04f7c39a..839942bea 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -324,6 +324,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -795,6 +797,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -1253,6 +1257,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): x, layer.w13_weight_packed, layer.w2_weight_packed, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 8ba721629..63bfe565b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) +from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer logger = init_logger(__name__) @@ -24,6 +25,13 @@ __all__ = ["CompressedTensorsW4A4Fp4"] class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): def __init__(self): + if envs.VLLM_USE_TRTLLM_FP4_GEMM: + assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" + self.backend = "flashinfer-trtllm" + elif has_flashinfer(): + self.backend = "flashinfer-cutlass" + else: + self.backend = "cutlass" self.group_size = 16 @classmethod @@ -108,16 +116,36 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): layer.weight_global_scale.max().to(torch.float32), requires_grad=False) - swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) + if self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - # required by cutlass kernel; need Parameter, not ModelWeightParameter - layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) + weight = layer.weight_packed.data + weight_scale = layer.weight_scale.data - layer.alpha = Parameter(layer.input_global_scale * - layer.weight_global_scale, - requires_grad=False) + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), + epilogue_tile_m) + weight_scale = (shuffle_matrix_sf_a(weight_scale.view( + torch.uint8), epilogue_tile_m).reshape( + weight_scale.shape).view(torch.float8_e4m3fn)) + + layer.weight_scale_swizzled = Parameter(weight_scale, + requires_grad=False) + layer.weight_packed = Parameter(weight, requires_grad=False) + else: + swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + layer.weight_packed = Parameter(layer.weight_packed.data, + requires_grad=False) + + layer.alpha = Parameter( + 1 / (layer.input_global_scale * layer.weight_global_scale), + requires_grad=False) def apply_weights(self, layer: torch.nn.Module, @@ -128,7 +156,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): out = run_nvfp4_emulations( x=x, input_global_scale=layer.input_global_scale, - weight=layer.weight, + weight=layer.weight_packed, weight_scale_swizzled=layer.weight_scale_swizzled, weight_global_scale=layer.weight_global_scale) if bias is not None: @@ -136,14 +164,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): return out output_dtype = x.dtype - output_shape = [x.shape[0], layer.weight.shape[0]] + output_shape = [x.shape[0], layer.weight_packed.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) - out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, - 1 / layer.alpha, output_dtype) + mm_args = (x_fp4, layer.weight_packed, x_blockscale, + layer.weight_scale_swizzled, layer.alpha, output_dtype) + if self.backend == "flashinfer-trtllm": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") + elif self.backend == "flashinfer-cutlass": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + else: + out = cutlass_scaled_fp4_mm(*mm_args) + if bias is not None: out = out + bias return out.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9577fa025..dbd523428 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights, - swap_w13_to_w31) + apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, + rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -694,6 +694,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data if not self.block_quant: + register_moe_scaling_factors(layer) rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) else: w13_weight = layer.w13_weight.data @@ -983,6 +984,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9bed5e2e4..3299221e3 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import ( get_dynamic_override, get_linear_quant_method, override_config) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, check_moe_marlin_supports_layer, - marlin_make_workspace_new, marlin_moe_permute_scales, + marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias, marlin_repeat_scales_on_all_ranks, verify_marlin_supported) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -618,6 +618,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) replace_parameter(layer, "w2_scales", marlin_w2_scales) + if hasattr(layer, "w13_bias") and layer.w13_bias is not None: + layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) + + if hasattr(layer, "w2_bias") and layer.w2_bias is not None: + layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def apply( self, layer: torch.nn.Module, @@ -662,6 +668,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): x, layer.w13_qweight, layer.w2_qweight, + getattr(layer, "w13_bias", None), + getattr(layer, "w2_bias", None), layer.w13_scales, layer.w2_scales, router_logits, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index ee8a0e34b..8385ccac3 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - marlin_make_empty_g_idx, marlin_permute_scales) + marlin_make_empty_g_idx, marlin_permute_bias, marlin_permute_scales) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack @@ -284,6 +284,9 @@ class HQQMarlinMethod(LinearMethodBase): layer.marlin_zeros = marlin_zp layer.marlin_scales = marlin_s + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply( self, layer: torch.nn.Module, @@ -307,6 +310,7 @@ class HQQMarlinMethod(LinearMethodBase): x, None, layer.marlin_qweight, + bias, scales, None, zeros, @@ -326,7 +330,4 @@ class HQQMarlinMethod(LinearMethodBase): if orig_type != torch.float16: marlin_out = marlin_out.to(orig_type) - if bias is not None: - marlin_out.add_(bias) - return marlin_out diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index 73e0b17ea..5eb993830 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -9,8 +9,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx, - marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) + marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, + marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types, + unpack_cols) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) from vllm.platforms import current_platform @@ -111,6 +112,9 @@ class MarlinLinearKernel(MPLinearKernel): self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index bed502226..22fbbab00 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -25,8 +25,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_kernel, flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights, - swap_w13_to_w31) + apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, + rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) @@ -38,7 +38,8 @@ from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.scalar_type import scalar_types from vllm.utils import next_power_of_2 -from vllm.utils.flashinfer import has_flashinfer_moe +from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer, + has_flashinfer_moe) logger = init_logger(__name__) @@ -429,6 +430,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + register_moe_scaling_factors(layer) def apply( self, @@ -724,16 +726,20 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config - self.cutlass_nvfp4_supported = cutlass_fp4_supported() - self.use_marlin = False - if not self.cutlass_nvfp4_supported: - if is_fp4_marlin_supported(): - self.use_marlin = True - else: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above.") + if envs.VLLM_USE_TRTLLM_FP4_GEMM: + assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" + self.backend = "flashinfer-trtllm" + elif has_flashinfer(): + self.backend = "flashinfer-cutlass" + elif cutlass_fp4_supported(): + self.backend = "cutlass" + elif is_fp4_marlin_supported(): + self.backend = "marlin" + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") def create_weights( self, @@ -815,17 +821,38 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): # block_size = 16; assert (layer.weight_scale.dtype == torch.float8_e4m3fn), ( "Weight Block scale must be represented as FP8-E4M3") - swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) - layer.weight = Parameter(layer.weight.data, requires_grad=False) + if self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - if self.use_marlin: - prepare_fp4_layer_for_marlin(layer) - del layer.alpha - del layer.input_scale - del layer.weight_scale_swizzled + weight = layer.weight.data + weight_scale = layer.weight_scale.data + + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), + epilogue_tile_m) + weight_scale = (shuffle_matrix_sf_a(weight_scale.view( + torch.uint8), epilogue_tile_m).reshape( + weight_scale.shape).view(torch.float8_e4m3fn)) + + layer.weight_scale_swizzled = Parameter(weight_scale, + requires_grad=False) + layer.weight = Parameter(weight, requires_grad=False) + else: + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + if self.backend == "marlin": + prepare_fp4_layer_for_marlin(layer) + del layer.alpha + del layer.input_scale + del layer.weight_scale_swizzled def apply( self, @@ -833,7 +860,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if self.use_marlin: + if self.backend == "marlin": return apply_fp4_marlin_linear( input=x, weight=layer.weight, @@ -859,9 +886,21 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) assert (layer.alpha.dtype == torch.float32) - out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, layer.alpha, - output_dtype) + mm_args = ( + x_fp4, + layer.weight, + x_blockscale, + layer.weight_scale_swizzled, + layer.alpha, + output_dtype, + ) + if self.backend == "flashinfer-trtllm": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") + elif self.backend == "flashinfer-cutlass": + out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + else: + out = cutlass_scaled_fp4_mm(*mm_args) + if bias is not None: out = out + bias return out.view(*output_shape) @@ -1330,6 +1369,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 03fbcf158..dbe6c603c 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -15,13 +15,17 @@ from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( _can_support_mxfp4, _swizzle_mxfp4) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import next_power_of_2, round_up +from vllm.scalar_type import scalar_types +from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, + next_power_of_2, round_up) if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): @@ -81,6 +85,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): super().__init__() self.topk_indices_dtype = None self.moe = moe + self.use_marlin = self._should_use_marlin() + + def _should_use_marlin(self): + if envs.VLLM_MXFP4_USE_MARLIN is not None: + return envs.VLLM_MXFP4_USE_MARLIN + if current_platform.is_cuda() and \ + not current_platform.has_device_capability(100): + if not current_platform.is_device_capability(90): + # marlin kernel has better performance on ampere + return True + if not has_triton_kernels(): + return True + if not is_torch_equal_or_newer("2.8.0"): + return True + return False def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -101,11 +120,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): intermediate_size_per_partition_after_pad = \ intermediate_size_per_partition - # pad the intermediate size to be a multiple of 2 * mxfp4_block - # for to hold non-uniform sharded tensor as well as swizzling - # other padding to increase performance - if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + if self.use_marlin: + # The moe marlin kernel requires that for each linear + # n % 256 == 0 and k % 128 == 0. + # In gate_up_proj: + # n = 2 * intermediate_size_per_partition_after_pad + # k = hidden_size + # In down_proj + # n = hidden_size + # k = intermediate_size_per_partition_after_pad + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) + hidden_size = round_up(hidden_size, 256) + + layer.params_dtype = params_dtype + layer.num_experts = num_experts + layer.hidden_size = hidden_size + layer.intermediate_size_per_partition = \ + intermediate_size_per_partition_after_pad + elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + # pad the intermediate size to be a multiple of 2 * mxfp4_block + # for to hold non-uniform sharded tensor as well as swizzling + # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256) hidden_size = round_up(hidden_size, 256) @@ -191,8 +228,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_bias, extra_weight_attrs) def process_weights_after_loading(self, layer): - if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False) @@ -399,13 +438,45 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_bias, + layer.w2_bias, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=None, + global_scale2=None, + quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + expert_map=expert_map) + assert _can_support_mxfp4( use_grouped_topk, topk_group, num_expert_group, expert_map, custom_routing_function, e_score_correction_bias, apply_router_weight_on_input, scoring_func, activation, expert_load_view, logical_to_physical_map, - logical_replica_count), ("MXFP4 are not supported\ - with this configuration.") + logical_replica_count), ( + "MXFP4 are not supported with this configuration.") if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 9fb194767..278ee5232 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -82,6 +82,12 @@ def apply_flashinfer_per_tensor_scale_fp8( apply_router_weight_on_input: bool, ) -> torch.Tensor: from flashinfer.fused_moe import RoutingMethodType + assert layer.output1_scales_scalar is not None, ( + "Expected output1_scales_scalar to be initialized") + assert layer.output1_scales_scalar is not None, ( + "Expected output1_scales_gate_scalar to be initialized") + assert layer.output1_scales_scalar is not None, ( + "Expected output2_scales_scalar to be initialized") from vllm.model_executor.models.llama4 import Llama4MoE assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ @@ -92,10 +98,10 @@ def apply_flashinfer_per_tensor_scale_fp8( hidden_states=hidden_states, input_scale=layer.w13_input_scale, gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale, gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale, - activation_scale=layer.w2_input_scale, + output1_scales_scalar=layer.output1_scales_scalar, + output1_scales_gate_scalar=layer.output1_scales_gate_scalar, + output2_scales_scalar=layer.output2_scales_scalar, num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, @@ -105,4 +111,36 @@ def apply_flashinfer_per_tensor_scale_fp8( local_num_experts=layer.local_num_experts, use_routing_scales_on_input=apply_router_weight_on_input, routing_method_type=RoutingMethodType.Llama4, - ) \ No newline at end of file + ) + + +def get_moe_scaling_factors( + input_scale: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + activation_scale: torch.Tensor, + gemm2_weights_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + output1_scales_scalar = gemm1_weights_scale * input_scale * ( + 1.0 / activation_scale) + output1_scales_gate_scalar = gemm1_weights_scale * input_scale + output2_scales_scalar = activation_scale * gemm2_weights_scale + + return output1_scales_scalar, output1_scales_gate_scalar, \ + output2_scales_scalar + + +def register_moe_scaling_factors(layer: torch.nn.Module) -> None: + output1_scales, output1_gate_scales, output2_scales = \ + get_moe_scaling_factors( + layer.w13_input_scale, layer.w13_weight_scale, + layer.w2_input_scale, layer.w2_weight_scale + ) + layer.register_parameter( + 'output1_scales_scalar', + torch.nn.Parameter(output1_scales, requires_grad=False)) + layer.register_parameter( + 'output1_scales_gate_scalar', + torch.nn.Parameter(output1_gate_scales, requires_grad=False)) + layer.register_parameter( + 'output2_scales_scalar', + torch.nn.Parameter(output2_scales, requires_grad=False)) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 7540a1516..02057b476 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor: + origin_shape = s.shape + _, scale_perm_single = get_scale_perms() + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + return s.reshape(*origin_shape).contiguous() + + def marlin_moe_permute_scales( s: torch.Tensor, size_k: int, @@ -410,6 +417,7 @@ def apply_gptq_marlin_linear( output = ops.gptq_marlin_gemm(reshaped_x, None, weight, + bias, weight_scale, None, weight_zp, @@ -425,9 +433,6 @@ def apply_gptq_marlin_linear( use_fp32_reduce=use_fp32_reduce, is_zp_float=False) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -456,6 +461,7 @@ def apply_awq_marlin_linear( output = ops.gptq_marlin_gemm(reshaped_x, None, weight, + bias, weight_scale, None, weight_zp, @@ -470,7 +476,4 @@ def apply_awq_marlin_linear( use_fp32_reduce=use_fp32_reduce, is_zp_float=False) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index ca10db69d..94ffdcd26 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -8,8 +8,8 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, - should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, + marlin_permute_scales, should_use_atomic_add_reduce) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -22,7 +22,7 @@ def is_fp4_marlin_supported(): return current_platform.has_device_capability(80) -def fp4_marlin_process_scales(marlin_scales): +def nvfp4_marlin_process_scales(marlin_scales): if not (marlin_scales >= 0).all(): logger.warning_once( "NVFP4 Marlin assumes the scales to be >=0, but has encountered " @@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales): return marlin_scales -def fp4_marlin_process_global_scale(global_scale): +def mxfp4_marlin_process_scales(marlin_scales): + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + return marlin_scales + + +def nvfp4_marlin_process_global_scale(global_scale): assert global_scale.dtype in [torch.half, torch.bfloat16] fp4_exponent = 2 if global_scale.dtype == torch.half: @@ -73,7 +86,7 @@ def apply_fp4_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - weight_scale_2: torch.Tensor, + weight_scale_2: Optional[torch.Tensor], workspace: torch.Tensor, size_n: int, size_k: int, @@ -94,6 +107,7 @@ def apply_fp4_marlin_linear( output = ops.gptq_marlin_gemm(a=reshaped_x, c=None, b_q_weight=weight, + b_bias=bias, b_scales=weight_scale, global_scale=weight_scale_2, b_zeros=None, @@ -107,9 +121,6 @@ def apply_fp4_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "be used leveraging the Marlin kernel. This may degrade " "performance for compute-heavy workloads.") + is_nvfp4 = hasattr(layer, "weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + part_size_n = layer.output_size_per_partition part_size_k = layer.input_size_per_partition param_dtype = layer.params_dtype @@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales - weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = layer.weight_scale.T.contiguous() + + if not is_nvfp4: + weight_scale = weight_scale.view(torch.float8_e8m0fnu) + + weight_scale = weight_scale.to(param_dtype) weight_scale = marlin_permute_scales(s=weight_scale, size_k=part_size_k, size_n=part_size_n, - group_size=16) - weight_scale = fp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + group_size=group_size) - weight_scale_2 = layer.weight_scale_2.to(param_dtype) - weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, - requires_grad=False) + if is_nvfp4: + weight_scale = nvfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, + requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + else: + weight_scale = mxfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, + requires_grad=False) + + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n, ) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) return @@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "be used leveraging the Marlin kernel. This may degrade " "performance for compute-heavy workloads.") + is_nvfp4 = hasattr(layer, "w13_weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + e = layer.num_experts k = layer.hidden_size n = layer.intermediate_size_per_partition @@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales for name in ["w13", "w2"]: - scales = getattr(layer, name + "_weight_scale").to(param_dtype) - global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + scales = getattr(layer, name + "_weight_scale") + if not is_nvfp4: + scales = scales.view(torch.float8_e8m0fnu) + scales = scales.to(param_dtype) + if is_nvfp4: + global_scale = getattr(layer, + name + "_weight_scale_2").to(param_dtype) tensor_list = [] if "w13" in name: @@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: size_n, size_k = k, n for i in range(e): - marlin_scales = marlin_permute_scales(s=scales[i].T, + scale = scales[i].T + + marlin_scales = marlin_permute_scales(s=scale, size_k=size_k, size_n=size_n, - group_size=16) - marlin_scales = fp4_marlin_process_scales(marlin_scales) + group_size=group_size) + if is_nvfp4: + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + else: + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.nn.Parameter(scales, requires_grad=False) setattr(layer, name + "_weight_scale", scales) - global_scale = fp4_marlin_process_global_scale(global_scale) - global_scale = torch.nn.Parameter(global_scale, requires_grad=False) - setattr(layer, name + "_weight_scale_2", global_scale) + if is_nvfp4: + global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, + requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + # BIAS + # Permute bias + for name in ["w13_bias", "w2_bias"]: + if not hasattr(layer, name): + continue + bias = getattr(layer, name).to(param_dtype) + + tensor_list = [] + for i in range(e): + expert_bias = bias[i] + + tensor_list.append(marlin_permute_bias(expert_bias)) + + bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + bias = torch.nn.Parameter(bias, requires_grad=False) + setattr(layer, name, bias) -def rand_marlin_weight_fp4_like(weight, group_size): +def rand_marlin_weight_nvfp4_like(weight, group_size): assert group_size > 0 size_n, size_k = weight.shape device = weight.device @@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size): size_k=size_k, size_n=size_n, group_size=group_size) - marlin_scales = fp4_marlin_process_scales(marlin_scales) + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) - global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = nvfp4_marlin_process_global_scale(global_scale) return weight_ref.T, marlin_qweight, marlin_scales, global_scale + + +def rand_marlin_weight_mxfp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = torch.randint(100, + 125, (size_n, size_k // group_size), + dtype=torch.uint8, + device=weight.device) + scales = scales.view(torch.float8_e8m0fnu) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) + + return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 5372c49d9..511e19545 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -8,8 +8,8 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, - should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, + marlin_permute_scales, should_use_atomic_add_reduce) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -58,6 +58,7 @@ def apply_fp8_marlin_linear( output = ops.gptq_marlin_gemm(a=reshaped_x, c=None, b_q_weight=weight, + b_bias=bias, b_scales=weight_scale, global_scale=None, b_zeros=None, @@ -71,9 +72,6 @@ def apply_fp8_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -160,6 +158,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n, ) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, size_k_first: bool = True) -> None: @@ -274,6 +277,23 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, setattr(layer, name + "_weight_scale", scales) + # BIAS + # Permute bias + for name in ["w13_bias", "w2_bias"]: + if not hasattr(layer, name): + continue + bias = getattr(layer, name).to(layer.orig_dtype) + + tensor_list = [] + for i in range(e): + expert_bias = bias[i] + + tensor_list.append(marlin_permute_bias(expert_bias)) + + bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + bias = torch.nn.Parameter(bias, requires_grad=False) + setattr(layer, name, bias) + def pack_fp8_to_int32(fp8_tensor: torch.Tensor, size_k_first: bool = True) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 95eabe149..deeb69bca 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -61,7 +61,7 @@ def _can_support_mxfp4(use_grouped_topk: bool = False, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, scoring_func: str = "softmax", - activation: str = "silu", + activation: str = "swiglu_oai", expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None): diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 6dfc28be7..10fce857a 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -8,7 +8,6 @@ import torch from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch -from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled @CustomOp.register("rotary_embedding") @@ -36,7 +35,6 @@ class RotaryEmbedding(CustomOp): cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled() def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" @@ -121,75 +119,6 @@ class RotaryEmbedding(CustomOp): self.cos_sin_cache, self.is_neox_style) return query, key - def forward_hip( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - is_nope_first=False, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - # currently only rotary embedding ops from AITER package are - # supported for HiP forward. - if self.is_rocm_aiter_enabled: - return self.forward_hip_rocm_aiter(positions, query, key, offsets, - is_nope_first) - return self.forward_native(positions, query, key, offsets) - - def forward_hip_rocm_aiter( - self, - positions: torch.Tensor, - # if is_nope_first - # [[batch_size, seq_len, num_heads, nope_size+rope_size] - # if NOT is_nope_first - # [[batch_size, seq_len, num_heads, rope_size+nope_size], - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - is_nope_first: bool = False, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.cos_sin_cache.device != query.device or \ - self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.device, - dtype=query.dtype) - cos, sin = self.cos_sin_cache.chunk(2, dim=-1) - - cos = cos.unsqueeze(-2).unsqueeze(-2) - sin = sin.unsqueeze(-2).unsqueeze(-2) - - rotate_style = 0 if self.is_neox_style else 1 - - num_tokens = positions.numel() - - query_shape = query.shape - query = query.view(1, num_tokens, -1, self.head_size) - if key is not None: - key_shape = key.shape - key = key.view(1, num_tokens, -1, self.head_size) - - positions = positions.view(*query.shape[:2]) - if offsets is not None: - offsets = offsets.view(*query.shape[:2]) - - if not is_nope_first: - query_ = query[..., :self.rotary_dim] - key_ = key[..., :self.rotary_dim] if key is not None else None - else: - query_ = query[..., -self.rotary_dim:] - key_ = key[..., -self.rotary_dim:] if key is not None else None - - if key_ is None: - torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip( - positions, sin, cos, query_, offsets, rotate_style, - is_nope_first) - return query.view(query_shape), None - - torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip( - positions, sin, cos, query_, key_, offsets, rotate_style, - is_nope_first) - - return query.view(query_shape), key.view(key_shape) - def forward_xpu( self, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 99b6bb212..8d821bea1 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int, return ramp_func -def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: +def yarn_get_mscale(scale: float = 1) -> float: if scale <= 1: return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 + return 0.1 * math.log(scale) + 1.0 diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index 5af671703..cd888b733 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from typing import Optional import torch @@ -9,7 +10,13 @@ from vllm.platforms import current_platform from .base import RotaryEmbedding from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range, - yarn_get_mscale, yarn_linear_ramp_mask) + yarn_linear_ramp_mask) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 class DeepseekScalingRotaryEmbedding(RotaryEmbedding): @@ -89,9 +96,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" - if self.is_rocm_aiter_enabled: - return self.forward_hip_rocm_aiter(positions, query, key, offsets) - assert key is not None query_rot = query[..., :self.rotary_dim] key_rot = key[..., :self.rotary_dim] diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py deleted file mode 100644 index 91a2318ba..000000000 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import torch - -import vllm.envs as envs -from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op - - -def is_rocm_rotary_embedding_enabled() -> bool: - return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER) - - -def rocm_aiter_rotary_emb_without_key_forward_hip_impl( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - import aiter as ops - if offsets is None: - ops.rope_cached_positions_fwd_inplace( - query, - cos, - sin, - positions, - rotate_style, - reuse_freqs_front_part=True, - nope_first=is_nope_first, - ) - else: - ops.rope_cached_positions_offsets_fwd_inplace( - query, - cos, - sin, - positions, - offsets, - rotate_style, - reuse_freqs_front_part=True, - nope_first=is_nope_first, - ) - - -def rocm_aiter_rotary_emb_with_key_forward_hip_impl( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - import aiter as ops - if offsets is None: - ops.rope_cached_positions_2c_fwd_inplace( - query, - key, - cos, - sin, - positions, - rotate_style, - reuse_freqs_front_part=True, - nope_first=is_nope_first, - ) - else: - ops.rope_cached_positions_offsets_2c_fwd_inplace( - query, - key, - cos, - sin, - positions, - offsets, - rotate_style, - reuse_freqs_front_part=True, - nope_first=is_nope_first, - ) - - -def rocm_aiter_rotary_emb_with_key_forward_hip_fake( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - pass - - -def rocm_aiter_rotary_emb_without_key_forward_hip_fake( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - pass - - -if is_rocm_rotary_embedding_enabled(): - - direct_register_custom_op( - op_name="rocm_aiter_rotary_emb_with_key_forward_hip", - op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl, - mutates_args=["key", "query"], - fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake, - dispatch_key=current_platform.dispatch_key, - ) - - direct_register_custom_op( - op_name="rocm_aiter_rotary_emb_without_key_forward_hip", - op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl, - mutates_args=["query"], - fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake, - dispatch_key=current_platform.dispatch_key, - ) \ No newline at end of file diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 4a2ae0758..e2cd31af5 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -12,7 +12,7 @@ from transformers import BambaConfig from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -26,7 +26,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -83,6 +83,7 @@ class BambaMixerDecoderLayer(nn.Module): def __init__(self, config: BambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: @@ -100,6 +101,8 @@ class BambaMixerDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -138,6 +141,7 @@ class BambaAttentionDecoderLayer(nn.Module): self, config: BambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -266,6 +270,7 @@ class BambaModel(nn.Module): super().__init__() config: BambaConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -289,6 +294,7 @@ class BambaModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -437,6 +443,18 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -528,10 +546,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6f21cd267..882df7e81 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -318,7 +318,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # get mamba page size mamba_page_size = MambaSpec( shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), - dtype=kv_cache_dtype, + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), block_size=model_config.max_model_len, ).page_size_bytes diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 85d64af5b..5e2b6d691 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -11,7 +11,7 @@ from transformers import FalconH1Config from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -25,7 +25,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module): def __init__( self, config: FalconH1Config, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -108,6 +109,8 @@ class FalconH1SSMDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, use_rms_norm=config.mamba_rms_norm, prefix=f"{prefix}.mixer", @@ -317,6 +320,7 @@ class FalconH1ParallelHybrid(nn.Module): self, config: FalconH1Config, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -339,6 +343,7 @@ class FalconH1ParallelHybrid(nn.Module): # Instantiate the SSM branch self.mamba = FalconH1SSMDecoderLayer( config=config, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, prefix=ssm_prefix, @@ -408,6 +413,7 @@ class FalconH1Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: FalconH1Config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -435,6 +441,7 @@ class FalconH1Model(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -519,6 +526,18 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -624,12 +643,14 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager( self.vllm_config, - self.lm_head.weight.dtype if hasattr( - self.lm_head, 'weight') else torch.bfloat16, self.config.num_hidden_layers, *mamba_state_shape, + *mamba_state_dtype, ) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index e59502f12..5704496b9 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -24,7 +24,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -50,6 +50,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: @@ -70,6 +71,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -137,6 +140,7 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): self, config: GraniteMoeHybridConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -217,6 +221,7 @@ class GraniteMoeHybridAttention(nn.Module): def __init__( self, config: GraniteMoeHybridConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -316,6 +321,7 @@ class GraniteMoeHybridModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -340,6 +346,7 @@ class GraniteMoeHybridModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -527,6 +534,18 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -625,10 +644,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.model_config.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index fbd310121..0b32d6f25 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -10,7 +10,7 @@ from transformers import JambaConfig from vllm import envs from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE @@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -94,6 +94,7 @@ class JambaMambaDecoderLayer(nn.Module): def __init__(self, config: JambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, is_lora_enabled: Optional[bool] = False, @@ -114,6 +115,8 @@ class JambaMambaDecoderLayer(nn.Module): rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, is_lora_enabled = self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, prefix=f"{prefix}.mixer", ) @@ -164,6 +167,7 @@ class JambaAttentionDecoderLayer(nn.Module): def __init__(self, config: JambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -280,6 +284,7 @@ class JambaModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -304,6 +309,7 @@ class JambaModel(nn.Module): config.layers_block_type[layer_idx]] return layer_class(config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -520,9 +526,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.vllm_config.parallel_config, LayerBlockType.mamba) state_shape = self.get_mamba_state_shape_from_config( self.vllm_config) + state_dtype = self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, - num_layers, *state_shape) + num_layers, *state_shape, + *state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -537,6 +545,18 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba1_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 80b63e153..f4aaf0c6f 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -9,13 +9,13 @@ from torch import nn from transformers import MambaConfig from vllm import envs -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -40,6 +40,7 @@ class MambaDecoderLayer(nn.Module): def __init__(self, config: MambaConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, is_lora_enabled: Optional[bool] = False, @@ -61,6 +62,8 @@ class MambaDecoderLayer(nn.Module): rms_norm_eps=mixer_rms_eps, activation=config.hidden_act, is_lora_enabled=self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, prefix=f"{prefix}.mixer") self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -88,6 +91,7 @@ class MambaModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -108,6 +112,7 @@ class MambaModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MambaDecoderLayer(config, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, is_lora_enabled=is_lora_enabled, @@ -243,9 +248,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): self.vllm_config.parallel_config, LayerBlockType.mamba) state_shape = self.get_mamba_state_shape_from_config( self.vllm_config) + state_dtype = self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, - num_layers, *state_shape) + num_layers, *state_shape, + *state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -254,6 +261,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): return hidden_states + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba1_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 75e92b017..3432cf29f 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -11,7 +11,7 @@ from transformers import MambaConfig from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm @@ -20,7 +20,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,6 +45,8 @@ class Mamba2DecoderLayer(nn.Module): def __init__(self, config: MambaConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: super().__init__() @@ -62,6 +64,8 @@ class Mamba2DecoderLayer(nn.Module): head_dim=config.head_dim, rms_norm_eps=config.layer_norm_epsilon, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -93,6 +97,8 @@ class Mamba2Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config is_lora_enabled = bool(lora_config) @@ -112,8 +118,11 @@ class Mamba2Model(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Mamba2DecoderLayer( - config, quant_config=quant_config, prefix=prefix), + lambda prefix: Mamba2DecoderLayer(config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -200,6 +209,18 @@ class Mamba2Model(nn.Module): class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -290,10 +311,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) else: diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 27685c59a..6b16e3ce7 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -24,9 +24,14 @@ class MambaCacheParams: class MambaCacheManager(ConstantSizeCache): - def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, - num_mamba_layers: int, conv_state_shape: tuple[int, int], - temporal_state_shape: tuple[int, int]): + def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int, + conv_state_shape: tuple[int, int], + temporal_state_shape: tuple[int, int], + conv_state_dtype: torch.dtype, + temporal_state_dtype: torch.dtype): + + self.conv_state_dtype = conv_state_dtype + self.temporal_state_dtype = temporal_state_dtype # Determine max batch size to set size of MambaCache max_batch_size = vllm_config.scheduler_config.max_num_seqs @@ -40,11 +45,11 @@ class MambaCacheManager(ConstantSizeCache): assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + (conv_state_shape[1], conv_state_shape[0]), - dtype=dtype, + dtype=self.conv_state_dtype, device="cuda").transpose(-1, -2) temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + temporal_state_shape, - dtype=dtype, + dtype=self.temporal_state_dtype, device="cuda") self._mamba_cache = (conv_state, temporal_state) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 3d14a6ad5..82e96844c 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -16,7 +16,8 @@ from transformers import MiniMaxConfig from vllm import envs from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, @@ -36,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -338,6 +339,12 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): def mamba_type(self) -> str: return "linear_attention" + def get_state_dtype(self) -> tuple[torch.dtype]: + return MambaStateDtypeCalculator.linear_attention_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.linear_attention_state_shape( num_heads=self.num_heads, @@ -353,6 +360,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): max_position: int, block_size: int, num_hidden_layer: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, layer_idx: int = 0, linear_layer_idx: int = 0, @@ -374,6 +383,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): self.tp_heads = self.total_num_heads // self.tp_size self.qkv_size = self.num_heads * self.head_dim self.tp_hidden = self.head_dim * self.tp_heads + self.model_config = model_config + self.cache_config = cache_config self.prefix = prefix self.qkv_proj = ColumnParallelLinear( @@ -657,6 +668,7 @@ class MiniMaxText01DecoderLayer(nn.Module): def __init__( self, config: MiniMaxConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, expert_num: int = 1, @@ -693,6 +705,8 @@ class MiniMaxText01DecoderLayer(nn.Module): max_position=max_position_embeddings, block_size=config.block if hasattr(config, "block") else 256, num_hidden_layer=config.num_hidden_layers, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, layer_idx=self._ilayer, linear_layer_idx=linear_layer_id, @@ -861,6 +875,7 @@ class MiniMaxText01Model(nn.Module): def __init__( self, config: MiniMaxConfig, + model_config: Optional[ModelConfig] = None, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, scheduler_config=None, @@ -910,6 +925,7 @@ class MiniMaxText01Model(nn.Module): decoder_kwargs = { "quant_config": quant_config, "layer_id": layer_idx, + "model_config": model_config, "cache_config": cache_config } @@ -1111,8 +1127,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): self.config.max_model_len = vllm_config.model_config.max_model_len self.model = MiniMaxText01Model( self.config, - quant_config, + model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, + quant_config=quant_config, scheduler_config=vllm_config.scheduler_config, prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: @@ -1409,6 +1426,17 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): load_basic_weight(name, loaded_weight, self) return loaded_params + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.linear_attention_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 08315a138..07cd5a4c6 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -26,7 +26,7 @@ from torch import nn from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -40,7 +40,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -110,6 +110,7 @@ class NemotronHMLPDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -149,6 +150,7 @@ class NemotronHMambaDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -167,6 +169,8 @@ class NemotronHMambaDecoderLayer(nn.Module): head_dim=config.mamba_head_dim, rms_norm_eps=config.rms_norm_eps, activation=config.mamba_hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer", ) @@ -198,6 +202,7 @@ class NemotronHAttention(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -270,6 +275,7 @@ class NemotronHAttentionDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -279,6 +285,7 @@ class NemotronHAttentionDecoderLayer(nn.Module): self.mixer = NemotronHAttention( config, layer_idx, + model_config, cache_config, quant_config, prefix=f"{prefix}.mixer", @@ -317,6 +324,7 @@ class NemotronHModel(nn.Module): super().__init__() config: NemotronHConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -340,6 +348,7 @@ class NemotronHModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -478,6 +487,18 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -569,10 +590,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 4cb0becf3..ed65944c1 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -18,7 +18,7 @@ from transformers import Zamba2Config from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import GeluAndMul @@ -33,7 +33,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -478,6 +478,8 @@ class Zamba2MambaDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: """Initialize the Mamba decoder layer. @@ -502,6 +504,8 @@ class Zamba2MambaDecoderLayer(nn.Module): config.n_mamba_heads, rms_norm_eps=config.rms_norm_eps, activation="silu", + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -578,6 +582,8 @@ class Zamba2HybridLayer(nn.Module): shared_transformer: Zamba2AttentionDecoderLayer, config: Zamba2Config, block_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -596,6 +602,8 @@ class Zamba2HybridLayer(nn.Module): bias=False, quant_config=quant_config) self.mamba_decoder = Zamba2MambaDecoderLayer(config, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=prefix) @@ -669,6 +677,7 @@ class Zamba2Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -718,11 +727,15 @@ class Zamba2Model(nn.Module): Zamba2HybridLayer(block, config, block_idx, - quant_config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, prefix=prefix)) else: layers.append( Zamba2MambaDecoderLayer(config, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=prefix)) self.layers = nn.ModuleList(layers) @@ -848,6 +861,18 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): "1.weight": "B.weight", }) + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -966,10 +991,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) # Get cache parameters for current run mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 10f2dc025..761172e4d 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -5,16 +5,53 @@ Warmup kernels used during model execution. This is useful specifically for JIT'ed kernels as we don't want JIT'ing to happen during model execution. """ +from typing import TYPE_CHECKING + import torch import vllm.envs as envs from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup +from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.flashinfer import has_flashinfer + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.gpu_worker import Worker -def kernel_warmup(model: torch.nn.Module, max_tokens: int): +def kernel_warmup(worker: "Worker"): + # Deep GEMM warmup do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM and is_deep_gemm_supported() and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP) if do_deep_gemm_warmup: + model = worker.get_model() + max_tokens = worker.scheduler_config.max_num_batched_tokens deep_gemm_warmup(model, max_tokens) + + # FlashInfer autotune for Blackwell (SM 10.0) GPUs + if has_flashinfer() and current_platform.is_device_capability(100): + flashinfer_autotune(worker.model_runner) + + +def flashinfer_autotune(runner: "GPUModelRunner") -> None: + """ + Autotune FlashInfer operations. + FlashInfer have many implementations for the same operation, + autotuning runs benchmarks for each implementation and stores + the results. The results are cached transparently and + future calls to FlashInfer will use the best implementation. + Without autotuning, FlashInfer will rely on heuristics, which may + be significantly slower. + """ + from vllm.utils.flashinfer import autotune + + with torch.inference_mode(), autotune(): + # We skip EPLB here since we don't want to record dummy metrics + # When autotuning with number of tokens m, flashinfer will autotune + # operations for all number of tokens up to m. + # So we only need to run with the max number of tokens. + runner._dummy_run(runner.scheduler_config.max_num_batched_tokens, + skip_eplb=True, + is_profile=True) diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index ac27bb66f..c9ce1f0be 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pickle +import uuid from collections.abc import Iterable, Mapping from typing import Union @@ -34,6 +35,11 @@ class MultiModalHasher: return np.array(obj).tobytes() if isinstance(obj, Image.Image): + exif = obj.getexif() + if Image.ExifTags.Base.ImageID in exif and isinstance( + exif[Image.ExifTags.Base.ImageID], uuid.UUID): + # If the image has exif ImageID tag, use that + return exif[Image.ExifTags.Base.ImageID].bytes return cls.item_to_bytes( "image", np.asarray(convert_image_mode(obj, "RGBA"))) if isinstance(obj, torch.Tensor): diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3b01ee7ad..f914d0dc6 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -9,6 +9,7 @@ from itertools import groupby from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from urllib.parse import ParseResult, urlparse +from urllib.request import url2pathname import numpy as np import numpy.typing as npt @@ -108,7 +109,7 @@ class MediaConnector: raise RuntimeError("Cannot load local files without " "`--allowed-local-media-path`.") - filepath = Path(url_spec.path) + filepath = Path(url2pathname(url_spec.path)) if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( f"The file path {filepath} must be a subpath " diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 63f6b373c..321db8287 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -177,17 +177,20 @@ class CudaPlatformBase(Platform): logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA backend.") + # lazy import to avoid circular import + from vllm.config import CUDAGraphMode + compilation_config = vllm_config.compilation_config if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 - and compilation_config.use_cudagraph): + and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): logger.info( - "Data Parallel: Forcing enforce eager to be True since DP " + "Data Parallel: disabling cudagraphs since DP " "with DeepEP high-throughput kernels are not CUDA Graph " "compatible. The DeepEP low-latency kernels are CUDA Graph " "compatible. Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") - compilation_config.use_cudagraph = False + compilation_config.cudagraph_mode = CUDAGraphMode.NONE if model_config is not None: model_config.enforce_eager = True @@ -316,7 +319,7 @@ class CudaPlatformBase(Platform): # FlashAttention is the default for SM 8.0+ GPUs if cls.has_device_capability(80): - if has_sink: + if has_sink and not cls.is_device_capability(90): logger.info_once("Using Triton backend on V1 engine.") return TRITON_ATTN_VLLM_V1 if is_default_backend_supported := is_attn_backend_supported( @@ -454,8 +457,8 @@ class CudaPlatformBase(Platform): return True @classmethod - def get_piecewise_backend_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 91d531490..4017f1ca7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,7 +7,7 @@ import random import sys from datetime import timedelta from platform import uname -from typing import TYPE_CHECKING, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import numpy as np import torch @@ -137,6 +137,8 @@ class Platform: additional_env_vars: list[str] = [] + _global_graph_pool: Optional[Any] = None + @property def supported_dtypes(self) -> list[torch.dtype]: """Returns the supported dtypes for the current platform.""" @@ -522,6 +524,15 @@ class Platform: " attribute.", self.device_type, key) return None + def get_global_graph_pool(self) -> Any: + """ + Return the global graph pool for the this platform. + """ + cls = self.__class__ + if cls._global_graph_pool is None: + cls._global_graph_pool = self.graph_pool_handle() + return cls._global_graph_pool + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: """ @@ -530,11 +541,11 @@ class Platform: raise NotImplementedError @classmethod - def get_piecewise_backend_cls(cls) -> str: + def get_static_graph_wrapper_cls(cls) -> str: """ - Get piecewise backend class for piecewise graph. + Get static graph wrapper class for static graph. """ - return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa + return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 2d5bee5fc..3ede86e15 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -421,8 +421,8 @@ class RocmPlatform(Platform): return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName @classmethod - def get_piecewise_backend_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index c7522a89c..ba06abd07 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -99,7 +99,7 @@ class TpuPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - from vllm.config import CompilationLevel + from vllm.config import CompilationLevel, CUDAGraphMode cache_config = vllm_config.cache_config # For v0, the default block size is 16. @@ -109,9 +109,17 @@ class TpuPlatform(Platform): # TPU only supports DYNAMO_ONCE compilation level if compilation_config.level != CompilationLevel.DYNAMO_ONCE: - logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") + logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and " + "disabling cudagraph.") compilation_config.level = CompilationLevel.DYNAMO_ONCE + if compilation_config.cudagraph_mode is None or \ + compilation_config.cudagraph_mode.max_cudagraph_mode() \ + != CUDAGraphMode.NONE: + logger.info("[TPU] CUDA graph is not supported on TPU, " + "disabling cudagraphs.") + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if compilation_config.backend == "": compilation_config.backend = "openxla" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index abd58dbbc..66ebc8ad9 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional import torch import vllm.envs as envs +from vllm.config import CUDAGraphMode from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS @@ -100,16 +101,17 @@ class XPUPlatform(Platform): # Instances created using VllmConfig() typically have model_config as # None by default. The modification involves adding a check to prevent # potential null exceptions check and update model config. - if model_config is not None: - if model_config.dtype == torch.bfloat16: - bf16_supported = cls.device_support_bf16() - if not bf16_supported: - model_config.dtype = torch.float16 - if not model_config.enforce_eager: - logger.warning( - "CUDA graph is not supported on XPU, fallback to the eager " - "mode.") - model_config.enforce_eager = True + if model_config is not None and model_config.dtype == torch.bfloat16 \ + and not cls.device_support_bf16(): + model_config.dtype = torch.float16 + + compilation_config = vllm_config.compilation_config + if compilation_config.cudagraph_mode is None or \ + compilation_config.cudagraph_mode.max_cudagraph_mode() \ + != CUDAGraphMode.NONE: + logger.info("[XPU] CUDA graph is not supported on XPU, " + "disabling cudagraphs.") + compilation_config.cudagraph_mode = CUDAGraphMode.NONE # check and update parallel config parallel_config = vllm_config.parallel_config diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index 9060b55c7..6f11ab8e0 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -327,6 +327,8 @@ class scalar_types: uint8 = ScalarType.uint(8, None) float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float8_e8m0fnu = ScalarType(8, 0, False, 0, True, + NanRepr.EXTD_RANGE_MAX_MIN) float16_e8m7 = ScalarType.float_IEEE754(8, 7) float16_e5m10 = ScalarType.float_IEEE754(5, 10) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index cae4eecc0..a1f8ad164 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -173,6 +173,7 @@ CYAN = '\033[1;36m' RESET = '\033[0;0m' STR_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, "half": torch.half, "bfloat16": torch.bfloat16, "float": torch.float, diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 6b23ed426..0d7d4b694 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -14,6 +14,7 @@ import os from typing import Any, Callable, NoReturn, Optional import requests +import torch import vllm.envs as envs from vllm.logger import init_logger @@ -193,6 +194,75 @@ def use_trtllm_attention( return use_trtllm +if has_flashinfer(): + + @torch.library.custom_op( + "vllm::flashinfer_mm_fp4", + mutates_args=[], + device_types="cuda", + ) + def flashinfer_mm_fp4( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + from flashinfer import mm_fp4 as flashinfer_mm_fp4_ + return flashinfer_mm_fp4_(A, + B, + A_scale, + B_scale, + g_scale, + dtype, + block_size=16, + backend=backend) + + @torch.library.register_fake("vllm::flashinfer_mm_fp4", ) + def flashinfer_mm_fp4_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + return torch.empty(A.shape[0], + B.shape[1], + dtype=dtype, + device=A.device) + + +def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, alpha: torch.Tensor, + out_dtype: torch.dtype, + backend: str) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 + assert a.stride(-1) == 1 and b.stride(-1) == 1 + assert a.shape[1] == b.shape[1] + assert block_scale_a.shape[1] == a.shape[1] // 8 + assert block_scale_b.shape[1] == b.shape[1] // 8 + + if backend == "cutlass": + block_scale_a = block_scale_a.view(torch.uint8) + block_scale_b = block_scale_b.view(torch.uint8) + + return flashinfer_mm_fp4( + a, + b.t(), + block_scale_a, + block_scale_b.t(), + alpha, + out_dtype, + backend=backend, + ) + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -205,4 +275,5 @@ __all__ = [ "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", "use_trtllm_attention", + "flashinfer_scaled_fp4_mm", ] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index a411477bc..ab7a71a39 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import numpy as np import torch @@ -154,9 +154,26 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \ - else AttentionCGSupport.ALWAYS + # FA3: + # Supports full cudagraphs for all cases. + # + # FA2: + # For FA2, a graph is captured with max_query_len=1, (which is what we + # capture by default for num_tokens <= max_num_seqs when there is no + # spec-decode) then these graphs will not work for mixed prefill-decode + # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling + # in FA2. + # In summary if we are running with spec decodes the graphs would + # work for mixed prefill-decode and uniform-decode. But for non-spec decodes + # the graphs would not work for mixed prefill-decode; sorta the inverse + # of UNIFORM_SINGLE_TOKEN_DECODE. + # Theres probably a better way to describe this using `AttentionCGSupport` + # but for now just set it to `UNIFORM_BATCH` to get use to drop down + # to FULL_AND_PIECEWISE. + # TODO(luka, lucas): audit FA2 as part of: + # https://github.com/vllm-project/vllm/issues/22945 + cudagraph_support = AttentionCGSupport.ALWAYS \ + if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder( self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) - self.use_full_cuda_graph = self.compilation_config.full_cuda_graph - if self.use_full_cuda_graph: - if not self.aot_schedule: - raise ValueError( - "AoT scheduling is required for full cuda graph.") - capture_sizes = self.compilation_config.cudagraph_capture_sizes - if not capture_sizes: - raise ValueError( - "cudagraph_capture_sizes should not be None when " - "full_cuda_graph is True.") - self.max_cudagraph_size = max(capture_sizes) + + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + + if self.use_full_cuda_graph and self.aot_schedule: + self.max_cudagraph_size = self.compilation_config.max_capture_size + if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. @@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder( seqlens=seq_lens, max_seq_len=max_seq_len, causal=causal) - - if self.use_full_cuda_graph: - assert scheduler_metadata is not None + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] self.scheduler_metadata[:n] = scheduler_metadata # NOTE(woosuk): We should zero out the rest of the scheduler @@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder( self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - max_num_splits = 0 - if (self.use_full_cuda_graph - and num_actual_tokens <= self.max_cudagraph_size): - # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. - max_num_splits = self.max_num_splits + if num_actual_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. + max_num_splits = self.max_num_splits attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder( causal=causal) return attn_metadata - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported (FA2 support checked separately) - return True - def use_cascade_attention(self, *args, **kwargs) -> bool: return use_cascade_attention(*args, **kwargs) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 12e5542d6..02decb171 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, is_pin_memory_available from vllm.utils.flashinfer import use_trtllm_attention @@ -183,8 +183,8 @@ class FlashInferMetadata: class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.PURE_DECODE_ONLY + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE reorder_batch_threshold: ClassVar[int] = 1 @@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.kv_cache_spec.block_size) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - self.enable_cuda_graph = self.compilation_config.full_cuda_graph + self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\ + decode_mode() == CUDAGraphMode.FULL if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. @@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): return self.build(0, m) - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - return common_attn_metadata.max_query_len == 1 - def use_cascade_attention(self, *args, **kwargs) -> bool: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index f0e4636fd..6cdc50908 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,14 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar +from typing import ClassVar, Optional import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec @@ -25,12 +26,15 @@ class Mamba1AttentionMetadata: query_start_loc: torch.Tensor context_lens_tensor: torch.Tensor state_indices_tensor: torch.Tensor - has_initial_states: torch.Tensor + has_initial_states: Optional[torch.Tensor] + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int class Mamba1AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba1AttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 def __init__( @@ -57,11 +61,23 @@ class Mamba1AttentionMetadataBuilder( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( query_start_loc.device) - has_initial_states = (context_lens_tensor > 0) + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=1)) + + has_initial_states = None + + if num_prefills > 0: + has_initial_states = context_lens_tensor > 0 return Mamba1AttentionMetadata( query_start_loc=query_start_loc, context_lens_tensor=context_lens_tensor, has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, ) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba2_attn.py similarity index 96% rename from vllm/v1/attention/backends/mamba_attn.py rename to vllm/v1/attention/backends/mamba2_attn.py index 3f84f8967..ace078e2b 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -89,8 +89,8 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.PURE_DECODE_ONLY + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE reorder_batch_threshold: ClassVar[int] = 1 @@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder( m.max_query_len = 1 # decode-only return self.build(0, m) - - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - return common_attn_metadata.max_query_len == 1 diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py index 852e0dfe1..d3a0c63c5 100644 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ b/vllm/v1/attention/backends/mamba_selectors.py @@ -3,7 +3,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend -from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index badff6765..f2610671f 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): "MLA only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." - m.max_query_len = 1 # decode-only + assert m.max_query_len == 1 # decode-only return self.build(0, m) @@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): return attn_metadata - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - return common_attn_metadata.max_query_len == 1 - class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index b23a8f0a5..6e1e5d653 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional +from typing import ClassVar, Optional import torch @@ -12,11 +12,19 @@ from vllm.attention.backends.abstract import (AttentionType, from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, - MLACommonMetadata) + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport logger = init_logger(__name__) +class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): + # enable full CUDA Graph support for decode-only capture + attn_cudagraph_support: ClassVar[ + AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + class CutlassMLABackend(MLACommonBackend): @staticmethod @@ -27,6 +35,10 @@ class CutlassMLABackend(MLACommonBackend): def get_impl_cls() -> type["CutlassMLAImpl"]: return CutlassMLAImpl + @staticmethod + def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: + return CutlassMLAMetadataBuilder + class SM100Workspace: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 2b0f52cf8..116744234 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.PURE_DECODE_ONLY + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): device_properties = torch.cuda.get_device_properties(self.device) num_sms = device_properties.multi_processor_count - if self.compilation_config.full_cuda_graph: + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.cg_buf_tile_scheduler_metadata = torch.zeros( # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize) # TileSchedulerMetaDataSize = 8 @@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): 1, # MQA for the decode path ) - if self.compilation_config.full_cuda_graph: + # TODO: we can disambiguate between decode and mixed-prefill decode here + # so we can only use the persistent buffer if a cudagraph is actually + # being used. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): assert self.cg_buf_tile_scheduler_metadata is not None assert self.cg_buf_num_splits is not None diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 8b55e1a30..082c7e6f7 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.PURE_DECODE_ONLY + # TODO(luka, lucas): audit this as part of: + # https://github.com/vllm-project/vllm/issues/22945 + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): max_num_pages = max_num_reqs * max_num_pages_per_req # Preparing persistent buffers - if vllm_config.compilation_config.full_cuda_graph: + # TODO: we can disambiguate between decode and mixed-prefill decode here + # so we can only use the persistent buffer if a cudagraph is actually + # being used. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device) @@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) - if self.compilation_config.full_cuda_graph: + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): num_actual_pages = paged_kv_indices.size(0) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index e8bffbef4..7d09ac0a4 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder( ) return attn_metadata - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported (FA2 support checked separately) - return True - def use_cascade_attention(self, *args, **kwargs) -> bool: return False diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index c33afbfeb..48a9af3de 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -58,8 +58,7 @@ class TritonAttentionMetadata: class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.ALWAYS + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder( ) return attn_metadata - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported - return True - class TritonAttentionBackend(AttentionBackend): diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 91eb84245..1c7d08798 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum): Here we do not consider the cascade attention, as currently it is never cudagraph supported.""" + ALWAYS = 3 + """Cudagraph always supported; supports mixed-prefill-decode""" + UNIFORM_BATCH = 2 + """Cudagraph supported for batches the only contain query lengths that are + the same, this can be used for spec-decode + i.e. "decodes" are 1 + num_speculative_tokens""" + UNIFORM_SINGLE_TOKEN_DECODE = 1 + """Cudagraph supported for batches the only contain query_len==1 decodes""" NEVER = 0 """NO cudagraph support""" - PURE_DECODE_ONLY = 1 - """Cudagraph supported for pure decode, need to run without - cudagraph for mixed prefill-decode batches""" - ALWAYS = 2 - """Cudagraph always supported""" class AttentionMetadataBuilder(abc.ABC, Generic[M]): - # Does this backend/builder support CUDA Graphs for attention. - attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + # Does this backend/builder support CUDA Graphs for attention (default: no). + cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.NEVER # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query @@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): """ raise NotImplementedError - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - """ - Can this batch (with given metadata) use CUDA Graphs for attention. - """ - return False - def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata) -> M: """ diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py new file mode 100644 index 000000000..02e65820b --- /dev/null +++ b/vllm/v1/cudagraph_dispatcher.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig +from vllm.forward_context import BatchDescriptor +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class CudagraphDispatcher: + """ + Runtime cudagraph dispatcher to dispach keys for multiple set of cudagraphs. + + The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one + for FULL cudagraph runtime mode. The keys are initialized depending on + attention support and what cudagraph mode is set in CompilationConfig. The + keys stored in dispatcher are the only source of truth for valid + cudagraphs that can be dispatched at runtime. + + At runtime, the dispatch method generates the runtime cudagraph mode (FULL, + PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) + based on the input key. After dispatching (commuicate via forward context), + the cudagraph wrappers will trust the dispatch key to do either capturing + or replaying (if mode matched), or pass through to the underlying runnable + without cudagraph (if mode no match or mode is NONE). + """ + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.cudagraph_mode = self.compilation_config.cudagraph_mode + + # Dict to store valid cudagraph dispatching keys. + self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = { + CUDAGraphMode.PIECEWISE: set(), + CUDAGraphMode.FULL: set(), + } + + assert not self.cudagraph_mode.requires_piecewise_compilation() or \ + (self.compilation_config.level == CompilationLevel.PIECEWISE and + self.compilation_config.splitting_ops_contain_attention()), \ + "Compilation level should be CompilationLevel.PIECEWISE when "\ + "cudagraph_mode piecewise cudagraphs is used, "\ + f"cudagraph_mode={self.cudagraph_mode}, "\ + f"compilation_level={self.compilation_config.level}, "\ + f"splitting_ops={self.compilation_config.splitting_ops}" + + self.keys_initialized = False + + def add_cudagraph_key(self, runtime_mode: CUDAGraphMode, + batch_descriptor: BatchDescriptor): + assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ + f"Invalid cudagraph runtime mode: {runtime_mode}" + self.cudagraph_keys[runtime_mode].add(batch_descriptor) + + def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, + uniform_decode_query_len: int): + # This should be called only after attention backend is initialized. + + # Note: we create all valid keys possible for cudagraph but do not + # guarantee all keys would be used. For example, we create keys for + # piecewise cudagraphs when it is piecewise compilation, which is always + # valid, but for attention backend support unified routine, we may not + # trigger capturing/replaying the piecewise cudagraphs depending on + # CompilationConfig.cudagraph_mode. In addition, if we allow lazy + # capturing in future PR, some keys may never be triggered. + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + for bs in self.compilation_config.cudagraph_capture_sizes: + self.add_cudagraph_key( + cudagraph_mode.mixed_mode(), + BatchDescriptor(num_tokens=bs, uniform_decode=False)) + + # if decode cudagraph mode is FULL, and we don't already have mixed + # mode full cudagraphs then add them here. + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \ + and cudagraph_mode.separate_routine(): + max_num_tokens = uniform_decode_query_len * \ + self.vllm_config.scheduler_config.max_num_seqs + cudagraph_capture_sizes_for_decode = [ + x for x in self.compilation_config.cudagraph_capture_sizes + if x <= max_num_tokens and x >= uniform_decode_query_len + ] + for bs in cudagraph_capture_sizes_for_decode: + self.add_cudagraph_key( + CUDAGraphMode.FULL, + BatchDescriptor(num_tokens=bs, uniform_decode=True)) + self.keys_initialized = True + + def dispatch( + self, batch_descriptor: BatchDescriptor + ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]: + """ + Given a batch descriptor, dispatch to a cudagraph mode. + A new batch descriptor is returned as we might dispatch a uniform batch + to a graph that supports a more general batch (uniform to non-uniform). + """ + # if not initialized, just skip dispatching. + if not self.keys_initialized: + logger.warning_once("cudagraph dispatching keys are not " + "initialized. No cudagraph will be used.") + return CUDAGraphMode.NONE, None + + # check if key exists for full cudagraph + if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, batch_descriptor + + # otherwise, check if non-uniform key exists + non_uniform_key = batch_descriptor.non_uniform + if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, non_uniform_key + + # also check if non-uniform key exists for more "general" + # piecewise cudagraph + if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]: + return CUDAGraphMode.PIECEWISE, non_uniform_key + + # finally, just return no cudagraphs + return CUDAGraphMode.NONE, None diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 5ffa55557..29ee0a9df 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -965,7 +965,7 @@ class DPAsyncMPClient(AsyncMPClient): # List of [waiting, running] pair per engine. # Used only by DPLBAsyncMPClient subclass. - self.lb_engines: list[list[int]] = [] + self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines] self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = self.resources.first_req_send_socket = ( @@ -1121,10 +1121,8 @@ class DPLBAsyncMPClient(DPAsyncMPClient): def get_core_engine_for_request( self, request: EngineCoreRequest) -> EngineIdentity: # Engines are in rank order. - current_counts = self.lb_engines if (eng_index := request.data_parallel_rank) is None: - if not current_counts: - return self.core_engine + current_counts = self.lb_engines # TODO use P2C alg for larger DP sizes num_engines = len(current_counts) min_score = sys.maxsize diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 3be6c4821..2ee55b585 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -107,6 +107,7 @@ class RequestState: self.max_tokens_param = max_tokens_param self.is_prefilling = True self.queue = queue + self.num_cached_tokens = 0 self.stats = RequestStateStats( arrival_time=arrival_time) if log_stats else None @@ -167,7 +168,6 @@ class RequestState: finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], kv_transfer_params: Optional[dict[str, Any]] = None, - num_cached_tokens: int = 0, ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: finished = finish_reason is not None @@ -195,7 +195,7 @@ class RequestState: return None return self._new_request_output(request_id, outputs, finished, - kv_transfer_params, num_cached_tokens) + kv_transfer_params) def _new_request_output( self, @@ -203,14 +203,14 @@ class RequestState: outputs: Union[list[CompletionOutput], list[PoolingOutput]], finished: bool, kv_transfer_params: Optional[dict[str, Any]] = None, - num_cached_tokens: int = 0, ) -> Union[RequestOutput, PoolingRequestOutput]: - if isinstance(outputs[0], PoolingOutput): + first_output = outputs[0] + if isinstance(first_output, PoolingOutput): assert len(outputs) == 1 return PoolingRequestOutput( request_id=request_id, - outputs=outputs[0], + outputs=first_output, prompt_token_ids=self.prompt_token_ids, finished=finished, ) @@ -229,7 +229,7 @@ class RequestState: outputs=cast(list[CompletionOutput], outputs), finished=finished, kv_transfer_params=kv_transfer_params, - num_cached_tokens=num_cached_tokens, + num_cached_tokens=self.num_cached_tokens, ) def _new_completion_output( @@ -308,11 +308,18 @@ class OutputProcessor: if req_state is not None: self.lora_states.abort_request(req_state) request_ids_to_abort.append(request_id) - else: - parent = self.parent_requests.pop(request_id, None) - if parent and parent.child_requests: - self.abort_requests(parent.child_requests) - request_ids_to_abort.extend(parent.child_requests) + # Produce final abort output. + if req_state.queue is not None and ( + request_output := req_state.make_request_output( + [], None, FinishReason.ABORT, None, None)): + req_state.queue.put(request_output) + elif parent := self.parent_requests.get(request_id): + # Abort children prior to removing the parent. + if parent.child_requests: + child_reqs = list(parent.child_requests) + child_reqs = self.abort_requests(child_reqs) + request_ids_to_abort.extend(child_reqs) + self.parent_requests.pop(request_id, None) return request_ids_to_abort def add_request( @@ -390,7 +397,7 @@ class OutputProcessor: finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason kv_transfer_params = engine_core_output.kv_transfer_params - num_cached_tokens = engine_core_output.num_cached_tokens + req_state.num_cached_tokens = engine_core_output.num_cached_tokens req_state.is_prefilling = False if pooling_output is None: @@ -411,7 +418,7 @@ class OutputProcessor: # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( new_token_ids, pooling_output, finish_reason, stop_reason, - kv_transfer_params, num_cached_tokens): + kv_transfer_params): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4ff96f978..429416afa 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec): @dataclass(frozen=True) class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] - dtype: torch.dtype + dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" @property def page_size_bytes(self) -> int: - num_elements = sum(prod(shape) for shape in self.shapes) - page_size = num_elements * get_dtype_size(self.dtype) + page_size = sum( + prod(shape) * get_dtype_size(dtype) + for (shape, dtype) in zip(self.shapes, self.dtypes)) if self.page_size_padded is not None: assert self.page_size_padded >= page_size return self.page_size_padded diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8fb964184..9460d91c5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -21,7 +21,9 @@ from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter -from vllm.config import (CompilationLevel, VllmConfig, +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.monitor import set_cudagraph_capturing_enabled +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, @@ -29,7 +31,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import DPMetadata, set_forward_context +from vllm.forward_context import (BatchDescriptor, DPMetadata, + set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -48,13 +51,15 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, - is_pin_memory_available, round_up, supports_dynamo) + GiB_bytes, LazyLoader, cdiv, check_use_alibi, + get_dtype_size, is_pin_memory_available, round_up, + supports_dynamo) from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, make_kv_sharing_fast_prefill_attention_metadata, reorder_batch_to_split_decodes_and_prefills) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, @@ -218,11 +223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): is_spec_decode=bool(self.vllm_config.speculative_config), ) - self.use_cuda_graph = ( - self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and self.vllm_config.compilation_config.use_cudagraph - and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -230,8 +230,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.cudagraph_batch_sizes = list( reversed(self.compilation_config.cudagraph_capture_sizes)) - self.full_cuda_graph = self.compilation_config.full_cuda_graph - # Cache the device properties. self._init_device_properties() @@ -326,6 +324,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_sharing_fast_prefill_logits_indices = torch.zeros( self.max_num_tokens, dtype=torch.int32, device=self.device) + self.uniform_decode_query_len = 1 if not self.speculative_config else \ + 1 + self.speculative_config.num_speculative_tokens + + # Cudagraph dispatcher for runtime cudagraph dispatching. + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) + self.mm_budget = (MultiModalBudget( self.model_config, self.scheduler_config, @@ -471,7 +475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert (task := pooling_params.task) is not None, ( "You did not set `task` in the API") - model = cast(VllmModelForPooling, self.model) + model = cast(VllmModelForPooling, self.get_model()) to_update = model.pooler.get_pooling_updates(task) to_update.apply(pooling_params) @@ -679,13 +683,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, - Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata]]: + ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata], + np.ndarray, Optional[CommonAttentionMetadata], int]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, - attention_cuda_graphs: whether attention can run in cudagraph logits_indices, spec_decode_metadata ] """ @@ -820,7 +822,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # valid, we fill the padded indices with the last index. self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( logits_indices[-1].item()) - if (self.use_cuda_graph + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and num_logits <= self.cudagraph_batch_sizes[-1]): # Use piecewise CUDA graphs. # Add padding to the batch size. @@ -925,17 +927,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue attn_metadata[layer_name] = attn_metadata_i - attention_cuda_graphs = all( - g.metadata_builder.can_run_in_cudagraph(common_attn_metadata) - for g in self._attn_group_iterator()) - # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens, - spec_decode_common_attn_metadata) + return (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens, spec_decode_common_attn_metadata, + max_num_scheduled_tokens) def _compute_cascade_attn_prefix_len( self, @@ -1259,6 +1257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return mm_embeds def get_model(self) -> nn.Module: + # get raw model out of the cudagraph wrapper. + if isinstance(self.model, CUDAGraphWrapper): + return self.model.unwrap() return self.model def get_supported_generation_tasks(self) -> list[GenerationTask]: @@ -1415,9 +1416,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return assert self.eplb_state is not None - assert is_mixture_of_experts(self.model) + model = self.get_model() + assert is_mixture_of_experts(model) self.eplb_state.step( - self.model, + model, is_dummy, is_profile, log_stats=self.parallel_config.eplb_log_balancedness, @@ -1507,15 +1509,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.vllm_config) # Prepare the decoder inputs. - (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens_np, - spec_decode_common_attn_metadata) = ( - self._prepare_inputs(scheduler_output)) + (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens_np, spec_decode_common_attn_metadata, + max_query_len) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.use_cuda_graph + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. + # Use CUDA graphs. # Add padding to the batch size. num_input_tokens = self.vllm_config.pad_for_cudagraph( num_scheduled_tokens) @@ -1581,10 +1582,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - # Some attention backends only support CUDA Graphs in pure decode. - # If attention doesn't support CUDA Graphs for this batch, but we - # compiled with full CUDA graphs, we have to skip them entirely. - skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_scheduled_tokens == self.input_batch.num_reqs * max_query_len) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=uniform_decode) + cudagraph_runtime_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch(batch_descriptor) # Run the model. # Use persistent buffers for CUDA graphs. @@ -1593,10 +1596,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - skip_cuda_graphs=skip_cuda_graphs, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, ), self.maybe_get_kv_connector_output( scheduler_output) as kv_connector_output: - model_output = self.model( input_ids=input_ids, positions=positions, @@ -2021,20 +2024,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model.compile( fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=backend) + return + # for other compilation levels, cudagraph behavior is controlled by + # CudagraphWraper and CudagraphDispatcher of vllm. + + # wrap the model with full cudagraph wrapper if needed. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.model = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, \ "Cannot reload weights before model is loaded." model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model_loader.load_weights(self.model, model_config=self.model_config) + model = self.get_model() + model_loader.load_weights(model, model_config=self.model_config) def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", ) -> None: + model = self.get_model() TensorizerLoader.save_model( - self.model, + model, tensorizer_config=tensorizer_config, model_config=self.model_config, ) @@ -2210,31 +2224,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _dummy_run( self, num_tokens: int, - capture_attn_cudagraph: bool = False, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + force_attention: bool = False, + uniform_decode: bool = False, skip_eplb: bool = False, is_profile: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a dummy forward pass to warm up/profile run or capture the + CUDA graph for the model. + + Args: + num_tokens: Number of tokens to run the dummy forward pass. + cudagraph_runtime_mode: used to control the behavior. + - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run + - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. + - CUDAGraphMode.FULL: Full cudagraph, attention metadata is + needed. + force_attention: If True, always create attention metadata. Used to + warm up attention backend when mode is NONE. + uniform_decode: If True, the batch is a uniform decode batch. + skip_eplb: If True, skip EPLB state update. + is_profile: If True, this is a profile run. + """ + assert cudagraph_runtime_mode in { + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL + } # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.seperate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else \ + num_tokens + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - num_reqs = min(num_tokens, max_num_reqs) - min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs - num_scheduled_tokens_list[-1] += num_tokens % num_reqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + assert num_reqs <= max_num_reqs, \ + "Do not capture num_reqs > max_num_reqs for uniform batch" + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) attn_metadata: Optional[dict[str, Any]] = None - if capture_attn_cudagraph: + + # If force_attention is True, we always capture attention. Otherwise, + # it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == \ + CUDAGraphMode.FULL: attn_metadata = {} # Make sure max_model_len is used at the graph capture time. @@ -2255,7 +2320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens_cpu_tensor[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, - max_query_len=num_tokens, + max_query_len=max_query_len, block_table_tensor=self.input_batch.block_table[ kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. @@ -2299,12 +2364,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) + if cudagraph_runtime_mode == CUDAGraphMode.NONE: + batch_descriptor = None + else: + # filter out the valid batch descriptor + _cg_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, + uniform_decode=uniform_decode)) + # sanity check + assert cudagraph_runtime_mode == _cg_mode, ( + f"Cudagraph runtime mode mismatch at dummy_run. " + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): outputs = self.model( input_ids=input_ids, positions=positions, @@ -2436,7 +2515,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtype=torch.int32, device=self.device) - model = cast(VllmModelForPooling, self.model) + model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) to_update = model.pooler.get_pooling_updates(task) to_update.apply(dummy_pooling_params) @@ -2479,50 +2558,56 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. if self.supports_mm_inputs: - mm_budget = self.mm_budget - assert mm_budget is not None - - # TODO: handle encoder-decoder models once we support them. - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) - + if self.model_config.multimodal_config.skip_mm_profiling: logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the maximum " - "feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + "Skipping memory profiling for multimodal encoder and " + "encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + ( + dummy_modality, + max_tokens, + ) = mm_budget.get_modality_with_max_tokens() + ( + max_mm_items_per_prompt, + max_mm_items_per_batch, + ) = mm_budget.get_max_items(dummy_modality, max_tokens) - # Run multimodal encoder. - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + # Run multimodal encoder. + dummy_encoder_outputs = \ + self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ @@ -2540,12 +2625,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): gc.collect() def capture_model(self) -> None: - if not self.use_cuda_graph: + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "set -O %s and ensure `use_cudagraph` was not manually set to " - "False", CompilationLevel.PIECEWISE) + "ensure `cudagraph_mode` was not manually set to `NONE`") return + else: + self.initialize_cudagraph_capture() compilation_counter.num_gpu_runner_capture_triggers += 1 @@ -2570,25 +2656,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. + set_cudagraph_capturing_enabled(True) with freeze_gc(), graph_capture(device=self.device): - full_cg = self.full_cuda_graph - # Only rank 0 should print progress bar during capture - compilation_cases = reversed(self.cudagraph_batch_sizes) - if is_global_first_rank(): - compilation_cases = tqdm( - list(compilation_cases), - disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graph shapes") - for num_tokens in compilation_cases: - # We skip EPLB here since we don't want to record dummy metrics - for _ in range( - self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - skip_eplb=True) - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - skip_eplb=True) + cudagraph_mode = self.compilation_config.cudagraph_mode + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + cudagraph_runtime_mode = cudagraph_mode.mixed_mode() + + compilation_cases = list(reversed(self.cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=False) + + # Capture full cudagraph for uniform decode batches if we have + # dont already have full mixed prefill-decode cudagraphs + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + cudagraph_mode.separate_routine(): + max_num_tokens = self.scheduler_config.max_num_seqs * \ + self.uniform_decode_query_len + decode_cudagraph_batch_sizes = [ + x for x in self.cudagraph_batch_sizes if + x <= max_num_tokens and x >= self.uniform_decode_query_len + ] + compilation_cases_decode = list( + reversed(decode_cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases=compilation_cases_decode, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True) + + # Disable cudagraph capturing globally, so any unexpected cudagraph + # capturing will be detected and raise an error after here. + # Note: We don't put it into graph_capture context manager because + # we may doing lazy capturing in future that still allows capturing + # after here. + set_cudagraph_capturing_enabled(False) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -2598,6 +2700,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + def _capture_cudagraphs(self, compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool): + assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ + cudagraph_runtime_mode in [CUDAGraphMode.FULL, + CUDAGraphMode.PIECEWISE] + + # Only rank 0 should print progress bar during capture + if is_global_first_rank(): + compilation_cases = tqdm( + compilation_cases, + disable=not self.load_config.use_tqdm_on_load, + desc="Capturing CUDA graphs ({}, {})".format( + "decode" if uniform_decode else "mixed prefill-decode", + cudagraph_runtime_mode.name)) + # We skip EPLB here since we don't want to record dummy metrics + for num_tokens in compilation_cases: + for _ in range(self.compilation_config.cudagraph_num_of_warmups): + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. + force_attention = ( + cudagraph_runtime_mode == CUDAGraphMode.FULL) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + skip_eplb=True) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + skip_eplb=True) + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. @@ -2642,25 +2779,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata_builder_i, layer_names) attn_groups.append(attn_group) - - if self.full_cuda_graph: - if attn_metadata_builder_i.attn_cudagraph_support == \ - AttentionCGSupport.NEVER: - raise ValueError( - f"Full CUDAGraph not supported for " - f"{attn_backend.__name__}. Turn off " - f"CompilationConfig.full_cuda_graph or use a " - f" different attention backend.") - if attn_metadata_builder_i.attn_cudagraph_support == \ - AttentionCGSupport.PURE_DECODE_ONLY: - # Limit the max cudagraph size to the max number of - # sequences for pure decode only cudagraph backend, - # whose max_query_len is 1. - self.cudagraph_batch_sizes = [ - size for size in self.cudagraph_batch_sizes - if size <= self.scheduler_config.max_num_seqs - ] - return attn_groups for kv_cache_group_spec in kv_cache_config.kv_cache_groups: @@ -2728,6 +2846,75 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): "All or none of the layers are expected to be encoder-only" self.is_encoder_only_model = True + def initialize_cudagraph_capture(self) -> None: + min_cg_support = AttentionCGSupport.ALWAYS + min_cg_builder_name = None + + for attn_group in self._attn_group_iterator(): + builder = attn_group.metadata_builder + if builder.cudagraph_support.value < min_cg_support.value: + min_cg_support = builder.cudagraph_support + min_cg_builder_name = builder.__class__.__name__ + + # Flexible resolve the cudagraph mode + cudagraph_mode = self.compilation_config.cudagraph_mode + # check cudagraph for mixed batch is supported + if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ + and min_cg_support != AttentionCGSupport.ALWAYS: + msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})") + if min_cg_support == AttentionCGSupport.NEVER: + # if not supported any full cudagraphs, just raise it. + msg += "; please try cudagraph_mode=PIECEWISE, and "\ + "make sure compilation level is piecewise" + raise ValueError(msg) + + # attempt to resolve the full cudagraph related mode + if self.compilation_config.splitting_ops_contain_attention(): + msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_AND_PIECEWISE + else: + msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_DECODE_ONLY + logger.warning(msg) + + # check that if we are doing spec-decode + decode full-cudagraphs it is + # supported + if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 and min_cg_support.value + < AttentionCGSupport.UNIFORM_BATCH.value): + msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_builder_name} (support: {min_cg_support})") + if self.compilation_config.splitting_ops_contain_attention(): + msg += "; setting cudagraph_mode=PIECEWISE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + msg += "; setting cudagraph_mode=NONE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.NONE + logger.warning(msg) + + # double check that we can support full cudagraph if they are requested + # even after automatic downgrades + if cudagraph_mode.has_full_cudagraphs() \ + and min_cg_support == AttentionCGSupport.NEVER: + raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_builder_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise") + + # Trigger cudagraph dispatching keys initialization here (after + # initializing attn backends). + self.cudagraph_dispatcher.initialize_cudagraph_keys( + self.compilation_config.cudagraph_mode, + self.uniform_decode_query_len) + def calculate_reorder_batch_threshold(self) -> None: """ Check that if any backends reorder batches; that the reordering @@ -2878,23 +3065,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] - dtype = kv_cache_spec.dtype - num_element_per_page = (kv_cache_spec.page_size_bytes // - get_dtype_size(dtype)) state_tensors = [] - storage_offset = 0 - for shape in kv_cache_spec.shapes: + storage_offset_bytes = 0 + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 tensor = torch.as_strided( raw_tensor.view(dtype), size=target_shape, stride=target_stride, - storage_offset=storage_offset, + storage_offset=storage_offset_bytes // dtype_size, ) state_tensors.append(tensor) - storage_offset += stride[0] + storage_offset_bytes += stride[0] * dtype_size kv_caches[layer_name] = state_tensors else: @@ -3081,7 +3270,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), - dtype=self.kv_cache_dtype, + dtypes=mamba_module.get_state_dtype(), block_size=max_model_len, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0ea23921a..04de8d366 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -310,6 +310,7 @@ class Worker(WorkerBase): for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size, skip_eplb=True) + if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -321,16 +322,11 @@ class Worker(WorkerBase): if get_pp_group().is_last_rank: max_num_reqs = min(self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens) - # activate building attn_metadata for this dummy run to avoid - # potential illegal memory access for full cudagraph relay. - attn_cudagraph = self.compilation_config.full_cuda_graph and\ - not self.model_config.enforce_eager # We skip EPLB here since we don't want to record dummy metrics hidden_states, last_hidden_states = \ self.model_runner._dummy_run( num_tokens=max_num_reqs, - capture_attn_cudagraph=attn_cudagraph, skip_eplb=True, ) if self.model_runner.is_pooling_model: @@ -340,8 +336,7 @@ class Worker(WorkerBase): hidden_states=last_hidden_states) # Warmup kernels used during model execution - kernel_warmup(self.get_model(), - max_tokens=self.scheduler_config.max_num_batched_tokens) + kernel_warmup(self) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 46262284e..f7e68edba 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1529,60 +1529,66 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) -> None: # Profile with multimodal encoder & encoder cache. if self.supports_mm_inputs: - mm_budget = self.mm_budget - assert mm_budget is not None - - # TODO: handle encoder-decoder models once we support them. - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - ( - dummy_modality, - max_tokens, - ) = mm_budget.get_modality_with_max_tokens() - ( - max_mm_items_per_prompt, - max_mm_items_per_batch, - ) = mm_budget.get_max_items(dummy_modality, max_tokens) - + if self.model_config.multimodal_config.skip_mm_profiling: logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the maximum " - "feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + "Skipping memory profiling for multimodal encoder and " + "encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + ( + dummy_modality, + max_tokens, + ) = mm_budget.get_modality_with_max_tokens() + ( + max_mm_items_per_prompt, + max_mm_items_per_batch, + ) = mm_budget.get_max_items(dummy_modality, max_tokens) - # Run multimodal encoder. - # Isolate encoder graph from post-processing to minimize - # impact of recompilation until it's fixed. - start = time.perf_counter() - xm.mark_step() - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - xm.mark_step() - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal Encoder profiling finished in in %.2f [secs].", - end - start) + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + # Run multimodal encoder. + # Isolate encoder graph from post-processing to minimize + # impact of recompilation until it's fixed. + start = time.perf_counter() + xm.mark_step() + dummy_encoder_outputs = \ + self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal Encoder profiling finished in %.2f [secs].", + end - start) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. self._dummy_run(num_tokens, self.num_reqs_max_model_len, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a63797e3a..a1c08fa81 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -762,8 +762,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): has Prefills (if any). The rest of the steps are guaranteed to be all decodes. In this case, we set up the padding as if all the sequences are decodes so we may run all steps except the first step in CUDA graph - mode. The padding is accounted for in the multi-step `advance_step` - family of functions. + mode. Args: num_seqs (int): Number of sequences scheduled to run.