diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index b2db7968e..3df3a606c 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -231,17 +231,22 @@ def use_torch( lora_a_stacked, lora_b_stacked, top_k_num, + num_slices=1, ): outputs = [] for i in range(hidden_states.shape[0]): - lora_idx = token_lora_mapping[i] - expert_ids = topk_ids[i] - lora_a = lora_a_stacked[0][lora_idx][expert_ids] - lora_b = lora_b_stacked[0][lora_idx][expert_ids] - tensors = [ - hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num) - ] - outputs.append(torch.stack(tensors, dim=0)) + slice_tensors = [] + for slice_id in range(num_slices): + lora_idx = token_lora_mapping[i] + expert_ids = topk_ids[i] + lora_a = lora_a_stacked[slice_id][lora_idx][expert_ids] + lora_b = lora_b_stacked[slice_id][lora_idx][expert_ids] + tensors = [ + hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num) + ] + slice_tensors.append(torch.stack(tensors, dim=0)) + + outputs.append(torch.concat(slice_tensors, dim=-1)) return torch.stack(outputs, dim=0) @@ -259,6 +264,7 @@ SEED = [42] @pytest.mark.parametrize("K", [2048]) @pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) @pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_slices", [1, 2]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("seed", SEED) @@ -271,6 +277,7 @@ def test_fused_moe_lora_kernel( K, max_lora_rank, block_size, + num_slices, dtype, device, seed, @@ -295,17 +302,19 @@ def test_fused_moe_lora_kernel( ), dtype=dtype, ) + for _ in range(num_slices) ] lora_b_stacked = [ torch.rand( ( max_loras, num_experts, - N, + N // num_slices, max_lora_rank, ), dtype=dtype, ) + for _ in range(num_slices) ] hidden_states = torch.rand( ( @@ -340,6 +349,7 @@ def test_fused_moe_lora_kernel( lora_a_stacked, lora_b_stacked, top_k_num, + num_slices, ) torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2) @@ -434,6 +444,7 @@ def use_fused_moe_lora_kernel_naive( @pytest.mark.parametrize("K", [2048]) @pytest.mark.parametrize("max_lora_rank", [16, 32]) @pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_slices", [1, 2]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("seed", SEED) @@ -446,6 +457,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment( K, max_lora_rank, block_size, + num_slices, dtype, device, seed, @@ -484,17 +496,19 @@ def test_fused_moe_lora_kernel_naive_block_assignment( ), dtype=dtype, ) + for _ in range(num_slices) ] lora_b_stacked = [ torch.rand( ( max_loras, num_experts, - N, + N // num_slices, max_lora_rank, ), dtype=dtype, ) + for _ in range(num_slices) ] hidden_states = torch.rand( ( @@ -529,6 +543,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment( lora_a_stacked, lora_b_stacked, top_k_num, + num_slices, ) torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/lora/test_olmoe_tp.py b/tests/lora/test_olmoe_tp.py index e10419d24..5e38638b9 100644 --- a/tests/lora/test_olmoe_tp.py +++ b/tests/lora/test_olmoe_tp.py @@ -2,7 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import shutil + import pytest +import torch +from safetensors.torch import load_file, save_file import vllm from vllm.lora.request import LoRARequest @@ -122,6 +126,41 @@ def test_olmoe_lora_mixed(olmoe_lora_files): generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None]) +def test_olmoe_lora_mixed_random(olmoe_lora_files, tmp_path): + # Create a dummy LoRA with random weights based on the real one + random_lora_path = tmp_path / "random_lora" + shutil.copytree(olmoe_lora_files, random_lora_path) + + weights_path = random_lora_path / "adapter_model.safetensors" + weights = load_file(str(weights_path)) + random_weights = {k: torch.randn_like(v) for k, v in weights.items()} + save_file(random_weights, str(weights_path)) + + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + + prompts = [ + PROMPT_TEMPLATE.format(context="How many candidates are there?"), + PROMPT_TEMPLATE.format(context="Count the number of candidates."), + ] + + lora_requests = [ + LoRARequest("real", 1, olmoe_lora_files), + LoRARequest("random", 2, str(random_lora_path)), + ] + + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests) + assert outputs[0].outputs[0].text.strip().startswith(EXPECTED_LORA_OUTPUT[0]) + + @pytest.mark.parametrize("fully_sharded_loras", [False, True]) @multi_gpu_test(num_gpus=2) def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras): diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index c9c85c194..8072f8769 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -8,9 +8,10 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.triton_utils import tl, triton +from vllm.triton_utils.allocation import set_triton_allocator from vllm.utils.torch_utils import direct_register_custom_op -from .utils import supports_pdl +from .utils import supports_pdl, supports_tma @triton.jit @@ -70,6 +71,37 @@ def _get_token_offs( ) +@triton.jit +def _get_c_ptrs( + cur_c_ptr, + lora_id, + pid_m, + offs, + offs_token, + offs_cn, + stride_cm, + stride_cn, + EM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + sort_c: tl.constexpr, +): + # When sort_c is true, store the output in c_ptr using token order defined + # in sorted_token_ids_ptr; otherwise, use the original token order from the prompt + if sort_c: + offs_token_id = pid_m * BLOCK_SIZE_M + offs + c_ptrs = ( + cur_c_ptr + + lora_id * EM * stride_cm + + stride_cm * offs_token_id[:, None] + + stride_cn * offs_cn[None, :] + ) + else: + c_ptrs = ( + cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + ) + return c_ptrs + + _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} @@ -125,7 +157,9 @@ def _adjust_kernel_inputs( ) def _fused_moe_lora_kernel( a_ptr, + a_desc, b_ptr, + b_desc, c_ptr, topk_weights_ptr, sorted_token_ids_ptr, @@ -177,6 +211,18 @@ def _fused_moe_lora_kernel( USE_GDC: tl.constexpr, launch_pdl: tl.constexpr, IS_PRIMARY: tl.constexpr, + USE_TMA: tl.constexpr, + # sort_c determines whether tokens are stored in C in the order determined + # by sorted_token_ids to enable later TMA loads from this tensor. + # + # When USE_TMA is enabled, the parameter combinations are: + # a_desc | b_desc | sort_c | Use Case + # --------|---------|--------|----------------------------- + # yes | yes | False | expand kernel (num_slices=1) + # no | yes | True | shrink kernel (num_slices=1) + # yes | no | False | expand kernel (num_slices>1) + # no | no | True | shrink kernel (num_slices>1) + sort_c: tl.constexpr, ): pid = tl.program_id(axis=0) slice_id = tl.program_id(axis=1) @@ -250,58 +296,90 @@ def _fused_moe_lora_kernel( cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size - # remove modulo wrap-around - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32) offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) token_mask = offs_token < num_valid_tokens - # get a_ptrs,b_ptrs - a_ptrs = cur_a_ptr + ( - offs_token[:, None] // token_mapping_factor * stride_am - + offs_k[None, :] * stride_ak - ) + if USE_TMA and a_desc is not None: + # Expand path - with TMA enabled, load from A using TMA descriptor + offs_am = ( + slice_id * max_loras * EM + + lora_id * EM + + pid_m * BLOCK_SIZE_M // token_mapping_factor + ) + offs_ak = pid_sk * BLOCK_SIZE_K + else: + # Shrink path - load hidden states based on order defined in + # 'sorted_token_ids_ptr' then store them in c_ptr in this same sorted order + tl.static_assert(a_desc is None, "a_desc must be none") + a_ptrs = cur_a_ptr + ( + offs_token[:, None] // token_mapping_factor * stride_am + + offs_k[None, :] * stride_ak + ) - b_ptrs = ( - cur_b_ptr - + lora_id * stride_bl - + expert_id * stride_be - + offs_k[:, None] * stride_bk - + offs_bn[None, :] * stride_bn - ) + if USE_TMA: + offs_bn = pid_n * BLOCK_SIZE_N + offs_bk = pid_sk * BLOCK_SIZE_K + if b_desc is None: + # Note(@gnovack) - Allocation of TMA descriptors on-device + # can cause conflicts when running in parallel via PDL + if USE_GDC and not IS_PRIMARY: + tl.extra.cuda.gdc_wait() + + b_desc = tl.make_tensor_descriptor( + cur_b_ptr, + shape=[max_loras, num_experts, N, K], + strides=[stride_bl, stride_be, stride_bn, stride_bk], + block_shape=[1, 1, BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + else: + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32) + b_ptrs = ( + cur_b_ptr + + lora_id * stride_bl + + expert_id * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) if USE_GDC and IS_PRIMARY: # GDC launch dependents hints the runtime system to launch dependent kernels. tl.extra.cuda.gdc_launch_dependents() - # accumulator accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) if USE_GDC and not IS_PRIMARY: tl.extra.cuda.gdc_wait() for k in range(0, grid_k): - k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) - # GDC wait waits for ALL programs in the prior kernel to complete - # before continuing. + cur_k_offset = k * (BLOCK_SIZE_K * SPLIT_K) + k_remaining = K - cur_k_offset # pre-fetch lora weight - # add (offs_bn < N) mask; optional .ca for B - b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N) - if USE_B_L2_CACHE: - b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca") + if b_desc is not None: + b = ( + b_desc.load([lora_id, expert_id, offs_bn, offs_bk + cur_k_offset]) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + .T + ) else: - b = tl.load(b_ptrs, mask=b_mask, other=0.0) + # add (offs_bn < N) mask; optional .ca for B + b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N) + if USE_B_L2_CACHE: + b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca") + else: + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + + if a_desc is not None: + a = a_desc.load([offs_am, offs_ak + cur_k_offset]) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), + other=0.0, + ) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - if USE_GDC and not IS_PRIMARY: - tl.extra.cuda.gdc_wait() - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), - other=0.0, - ) accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) @@ -309,7 +387,19 @@ def _fused_moe_lora_kernel( accumulator = accumulator.to(c_ptr.dtype.element_ty) # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_ptrs = _get_c_ptrs( + cur_c_ptr, + lora_id, + pid_m, + offs, + offs_token, + offs_cn, + stride_cm, + stride_cn, + EM, + BLOCK_SIZE_M, + sort_c, + ) c_mask = token_mask[:, None] & (offs_cn[None, :] < N) if SPLIT_K == 1: @@ -357,6 +447,7 @@ def _fused_moe_lora_shrink( num_active_loras: int, mul_routed_weight: bool = False, use_gdc: bool = False, + use_tma: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] shrink_config = { @@ -369,6 +460,7 @@ def _fused_moe_lora_shrink( "SPLIT_K": split_k, "USE_GDC": use_gdc, "launch_pdl": use_gdc, # triton kernel metadata + "USE_TMA": use_tma, } b_ptr = _get_ptr(lora_a_stacked, device) @@ -383,9 +475,20 @@ def _fused_moe_lora_shrink( len(lora_a_stacked), grid_lora_dim, ) + + a_desc = None + b_desc = None + if use_tma and num_slices == 1: + b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + lora_a_stacked[0], + [1, 1, shrink_config["BLOCK_SIZE_N"], shrink_config["BLOCK_SIZE_K"]], + ) + _fused_moe_lora_kernel[grid]( qcurr_hidden_states, + a_desc, b_ptr, + b_desc, a_intermediate_cache1, topk_weights, sorted_token_ids, @@ -407,8 +510,8 @@ def _fused_moe_lora_shrink( w1_lora_a_stacked.stride(1), w1_lora_a_stacked.stride(3), w1_lora_a_stacked.stride(2), - a_intermediate_cache1.stride(2), - a_intermediate_cache1.stride(3), + a_intermediate_cache1.stride(-2), + a_intermediate_cache1.stride(-1), stride_tl, stride_el, slice_a_size=qcurr_hidden_states.numel(), @@ -419,7 +522,8 @@ def _fused_moe_lora_shrink( naive_block_assignment=sorted_token_ids is None, MUL_ROUTED_WEIGHT=False, ADD_INPUTS=False, - USE_B_L2_CACHE=True, # new + USE_B_L2_CACHE=True, + sort_c=use_tma and sorted_token_ids is not None, IS_PRIMARY=True, **shrink_config, ) @@ -462,6 +566,7 @@ def _fused_moe_lora_expand( mul_routed_weight: bool = False, offset: int = 0, use_gdc: bool = False, + use_tma: bool = False, ) -> None: b_ptr = _get_ptr(lora_b_stacked, device) K = max_lora_rank @@ -470,7 +575,7 @@ def _fused_moe_lora_expand( w1_lora_b_stacked = lora_b_stacked[0] a_intermediate_cache1 = a_intermediate_cache1.view( - -1, a_intermediate_cache1.shape[3] + -1, a_intermediate_cache1.shape[-1] ) expand_config = { @@ -483,6 +588,7 @@ def _fused_moe_lora_expand( "SPLIT_K": 1, # Set split_k = 1 for expand calls "USE_GDC": use_gdc, "launch_pdl": use_gdc, # triton kernel metadata + "USE_TMA": use_tma, } grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs( @@ -498,10 +604,27 @@ def _fused_moe_lora_expand( # Fast path: directly accumulate into the corresponding slice interval of output. out_view = output[:, :, offset : offset + num_slices * N] slice_c_size = N * out_view.stride(2) + a_desc = None + b_desc = None + if use_tma: + if sorted_token_ids is not None: + a_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + a_intermediate_cache1, + [expand_config["BLOCK_SIZE_M"], expand_config["BLOCK_SIZE_K"]], + ) + if num_slices == 1: + b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + lora_b_stacked[0], + [1, 1, expand_config["BLOCK_SIZE_N"], expand_config["BLOCK_SIZE_K"]], + ) + else: + b_desc = None _fused_moe_lora_kernel[grid]( a_intermediate_cache1, + a_desc, b_ptr, + b_desc, out_view, topk_weights, sorted_token_ids, @@ -535,7 +658,8 @@ def _fused_moe_lora_expand( naive_block_assignment=sorted_token_ids is None, MUL_ROUTED_WEIGHT=mul_routed_weight, ADD_INPUTS=True, - USE_B_L2_CACHE=True, # new + USE_B_L2_CACHE=True, + sort_c=False, IS_PRIMARY=False, **expand_config, ) @@ -616,8 +740,34 @@ def _fused_moe_lora( else num_tokens * shrink_block_size_m ) + # TMA is not currently compatiple with fully_sharded due to the non-determinism + # of token id sorting across ranks. + use_tma = supports_tma(device) and not fully_sharded + + intermediate_cache_shape = ( + num_slices, + M, + top_k_num, + max_lora_rank, + ) + if use_tma: + if num_slices > 1: + # if num_slices > 1, we construct TMA descriptors for LoRA + # weights within the kernel, which requires us to first set an allocator + set_triton_allocator(device) + + # When storing intermediate data in sorted order for TMA, we + # need an extra 'num_active_loras' dim in the cache to avoid conflicts + if sorted_token_ids is not None: + intermediate_cache_shape = ( + num_slices, + sorted_token_ids.shape[0], + EM, + max_lora_rank, + ) + a_intermediate_cache1 = torch.zeros( - (num_slices, M, top_k_num, max_lora_rank), + intermediate_cache_shape, dtype=output.dtype, device=device, ) @@ -654,6 +804,7 @@ def _fused_moe_lora( num_active_loras, mul_routed_weight, use_gdc=use_gdc, + use_tma=use_tma, ) if fully_sharded: @@ -703,6 +854,7 @@ def _fused_moe_lora( mul_routed_weight, offset, use_gdc=use_gdc, + use_tma=use_tma, ) @@ -772,6 +924,7 @@ def _fused_moe_lora_shrink_fake( num_active_loras: int, mul_routed_weight: bool = False, use_gdc: bool = False, + use_tma: bool = False, ) -> None: return @@ -809,6 +962,7 @@ def _fused_moe_lora_expand_fake( mul_routed_weight: bool = False, offset: int = 0, use_gdc: bool = False, + use_tma: bool = False, ) -> None: return diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index c7ac5914b..a863b9726 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -316,3 +316,9 @@ def supports_pdl(device: torch.device | None = None) -> bool: and current_platform.has_device_capability(90) and not envs.VLLM_LORA_DISABLE_PDL ) + + +@lru_cache +def supports_tma(device: torch.device | None = None) -> bool: + # TMA requires compute capability SM90 or above + return current_platform.is_cuda() and current_platform.has_device_capability(90) diff --git a/vllm/triton_utils/allocation.py b/vllm/triton_utils/allocation.py new file mode 100644 index 000000000..e805f80b8 --- /dev/null +++ b/vllm/triton_utils/allocation.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.triton_utils import triton + + +def set_triton_allocator(device: torch.device): + def alloc_fn(size: int, alignment: int, stream: int | None): + return torch.empty(size, device=device, dtype=torch.int8) + + triton.set_allocator(alloc_fn)