Add TMA support to fused_moe_lora kernel (#32195)
Signed-off-by: gnovack <gnovack@amazon.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -231,17 +231,22 @@ def use_torch(
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
top_k_num,
|
||||
num_slices=1,
|
||||
):
|
||||
outputs = []
|
||||
for i in range(hidden_states.shape[0]):
|
||||
lora_idx = token_lora_mapping[i]
|
||||
expert_ids = topk_ids[i]
|
||||
lora_a = lora_a_stacked[0][lora_idx][expert_ids]
|
||||
lora_b = lora_b_stacked[0][lora_idx][expert_ids]
|
||||
tensors = [
|
||||
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
|
||||
]
|
||||
outputs.append(torch.stack(tensors, dim=0))
|
||||
slice_tensors = []
|
||||
for slice_id in range(num_slices):
|
||||
lora_idx = token_lora_mapping[i]
|
||||
expert_ids = topk_ids[i]
|
||||
lora_a = lora_a_stacked[slice_id][lora_idx][expert_ids]
|
||||
lora_b = lora_b_stacked[slice_id][lora_idx][expert_ids]
|
||||
tensors = [
|
||||
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
|
||||
]
|
||||
slice_tensors.append(torch.stack(tensors, dim=0))
|
||||
|
||||
outputs.append(torch.concat(slice_tensors, dim=-1))
|
||||
return torch.stack(outputs, dim=0)
|
||||
|
||||
|
||||
@@ -259,6 +264,7 @@ SEED = [42]
|
||||
@pytest.mark.parametrize("K", [2048])
|
||||
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("num_slices", [1, 2])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@@ -271,6 +277,7 @@ def test_fused_moe_lora_kernel(
|
||||
K,
|
||||
max_lora_rank,
|
||||
block_size,
|
||||
num_slices,
|
||||
dtype,
|
||||
device,
|
||||
seed,
|
||||
@@ -295,17 +302,19 @@ def test_fused_moe_lora_kernel(
|
||||
),
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(num_slices)
|
||||
]
|
||||
lora_b_stacked = [
|
||||
torch.rand(
|
||||
(
|
||||
max_loras,
|
||||
num_experts,
|
||||
N,
|
||||
N // num_slices,
|
||||
max_lora_rank,
|
||||
),
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(num_slices)
|
||||
]
|
||||
hidden_states = torch.rand(
|
||||
(
|
||||
@@ -340,6 +349,7 @@ def test_fused_moe_lora_kernel(
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
top_k_num,
|
||||
num_slices,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2)
|
||||
@@ -434,6 +444,7 @@ def use_fused_moe_lora_kernel_naive(
|
||||
@pytest.mark.parametrize("K", [2048])
|
||||
@pytest.mark.parametrize("max_lora_rank", [16, 32])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("num_slices", [1, 2])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@@ -446,6 +457,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
|
||||
K,
|
||||
max_lora_rank,
|
||||
block_size,
|
||||
num_slices,
|
||||
dtype,
|
||||
device,
|
||||
seed,
|
||||
@@ -484,17 +496,19 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
|
||||
),
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(num_slices)
|
||||
]
|
||||
lora_b_stacked = [
|
||||
torch.rand(
|
||||
(
|
||||
max_loras,
|
||||
num_experts,
|
||||
N,
|
||||
N // num_slices,
|
||||
max_lora_rank,
|
||||
),
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(num_slices)
|
||||
]
|
||||
hidden_states = torch.rand(
|
||||
(
|
||||
@@ -529,6 +543,7 @@ def test_fused_moe_lora_kernel_naive_block_assignment(
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
top_k_num,
|
||||
num_slices,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -122,6 +126,41 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
|
||||
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
|
||||
|
||||
|
||||
def test_olmoe_lora_mixed_random(olmoe_lora_files, tmp_path):
|
||||
# Create a dummy LoRA with random weights based on the real one
|
||||
random_lora_path = tmp_path / "random_lora"
|
||||
shutil.copytree(olmoe_lora_files, random_lora_path)
|
||||
|
||||
weights_path = random_lora_path / "adapter_model.safetensors"
|
||||
weights = load_file(str(weights_path))
|
||||
random_weights = {k: torch.randn_like(v) for k, v in weights.items()}
|
||||
save_file(random_weights, str(weights_path))
|
||||
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
|
||||
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
|
||||
]
|
||||
|
||||
lora_requests = [
|
||||
LoRARequest("real", 1, olmoe_lora_files),
|
||||
LoRARequest("random", 2, str(random_lora_path)),
|
||||
]
|
||||
|
||||
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
|
||||
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
|
||||
assert outputs[0].outputs[0].text.strip().startswith(EXPECTED_LORA_OUTPUT[0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):
|
||||
|
||||
@@ -8,9 +8,10 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.triton_utils.allocation import set_triton_allocator
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .utils import supports_pdl
|
||||
from .utils import supports_pdl, supports_tma
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -70,6 +71,37 @@ def _get_token_offs(
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _get_c_ptrs(
|
||||
cur_c_ptr,
|
||||
lora_id,
|
||||
pid_m,
|
||||
offs,
|
||||
offs_token,
|
||||
offs_cn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
EM: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
sort_c: tl.constexpr,
|
||||
):
|
||||
# When sort_c is true, store the output in c_ptr using token order defined
|
||||
# in sorted_token_ids_ptr; otherwise, use the original token order from the prompt
|
||||
if sort_c:
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + offs
|
||||
c_ptrs = (
|
||||
cur_c_ptr
|
||||
+ lora_id * EM * stride_cm
|
||||
+ stride_cm * offs_token_id[:, None]
|
||||
+ stride_cn * offs_cn[None, :]
|
||||
)
|
||||
else:
|
||||
c_ptrs = (
|
||||
cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
)
|
||||
return c_ptrs
|
||||
|
||||
|
||||
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
|
||||
|
||||
|
||||
@@ -125,7 +157,9 @@ def _adjust_kernel_inputs(
|
||||
)
|
||||
def _fused_moe_lora_kernel(
|
||||
a_ptr,
|
||||
a_desc,
|
||||
b_ptr,
|
||||
b_desc,
|
||||
c_ptr,
|
||||
topk_weights_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
@@ -177,6 +211,18 @@ def _fused_moe_lora_kernel(
|
||||
USE_GDC: tl.constexpr,
|
||||
launch_pdl: tl.constexpr,
|
||||
IS_PRIMARY: tl.constexpr,
|
||||
USE_TMA: tl.constexpr,
|
||||
# sort_c determines whether tokens are stored in C in the order determined
|
||||
# by sorted_token_ids to enable later TMA loads from this tensor.
|
||||
#
|
||||
# When USE_TMA is enabled, the parameter combinations are:
|
||||
# a_desc | b_desc | sort_c | Use Case
|
||||
# --------|---------|--------|-----------------------------
|
||||
# yes | yes | False | expand kernel (num_slices=1)
|
||||
# no | yes | True | shrink kernel (num_slices=1)
|
||||
# yes | no | False | expand kernel (num_slices>1)
|
||||
# no | no | True | shrink kernel (num_slices>1)
|
||||
sort_c: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
slice_id = tl.program_id(axis=1)
|
||||
@@ -250,58 +296,90 @@ def _fused_moe_lora_kernel(
|
||||
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
|
||||
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
|
||||
|
||||
# remove modulo wrap-around
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
|
||||
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
# get a_ptrs,b_ptrs
|
||||
a_ptrs = cur_a_ptr + (
|
||||
offs_token[:, None] // token_mapping_factor * stride_am
|
||||
+ offs_k[None, :] * stride_ak
|
||||
)
|
||||
if USE_TMA and a_desc is not None:
|
||||
# Expand path - with TMA enabled, load from A using TMA descriptor
|
||||
offs_am = (
|
||||
slice_id * max_loras * EM
|
||||
+ lora_id * EM
|
||||
+ pid_m * BLOCK_SIZE_M // token_mapping_factor
|
||||
)
|
||||
offs_ak = pid_sk * BLOCK_SIZE_K
|
||||
else:
|
||||
# Shrink path - load hidden states based on order defined in
|
||||
# 'sorted_token_ids_ptr' then store them in c_ptr in this same sorted order
|
||||
tl.static_assert(a_desc is None, "a_desc must be none")
|
||||
a_ptrs = cur_a_ptr + (
|
||||
offs_token[:, None] // token_mapping_factor * stride_am
|
||||
+ offs_k[None, :] * stride_ak
|
||||
)
|
||||
|
||||
b_ptrs = (
|
||||
cur_b_ptr
|
||||
+ lora_id * stride_bl
|
||||
+ expert_id * stride_be
|
||||
+ offs_k[:, None] * stride_bk
|
||||
+ offs_bn[None, :] * stride_bn
|
||||
)
|
||||
if USE_TMA:
|
||||
offs_bn = pid_n * BLOCK_SIZE_N
|
||||
offs_bk = pid_sk * BLOCK_SIZE_K
|
||||
if b_desc is None:
|
||||
# Note(@gnovack) - Allocation of TMA descriptors on-device
|
||||
# can cause conflicts when running in parallel via PDL
|
||||
if USE_GDC and not IS_PRIMARY:
|
||||
tl.extra.cuda.gdc_wait()
|
||||
|
||||
b_desc = tl.make_tensor_descriptor(
|
||||
cur_b_ptr,
|
||||
shape=[max_loras, num_experts, N, K],
|
||||
strides=[stride_bl, stride_be, stride_bn, stride_bk],
|
||||
block_shape=[1, 1, BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
)
|
||||
else:
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
|
||||
b_ptrs = (
|
||||
cur_b_ptr
|
||||
+ lora_id * stride_bl
|
||||
+ expert_id * stride_be
|
||||
+ offs_k[:, None] * stride_bk
|
||||
+ offs_bn[None, :] * stride_bn
|
||||
)
|
||||
|
||||
if USE_GDC and IS_PRIMARY:
|
||||
# GDC launch dependents hints the runtime system to launch dependent kernels.
|
||||
tl.extra.cuda.gdc_launch_dependents()
|
||||
|
||||
# accumulator
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
if USE_GDC and not IS_PRIMARY:
|
||||
tl.extra.cuda.gdc_wait()
|
||||
|
||||
for k in range(0, grid_k):
|
||||
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
|
||||
# GDC wait waits for ALL programs in the prior kernel to complete
|
||||
# before continuing.
|
||||
cur_k_offset = k * (BLOCK_SIZE_K * SPLIT_K)
|
||||
k_remaining = K - cur_k_offset
|
||||
# pre-fetch lora weight
|
||||
# add (offs_bn < N) mask; optional .ca for B
|
||||
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
|
||||
if USE_B_L2_CACHE:
|
||||
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
|
||||
if b_desc is not None:
|
||||
b = (
|
||||
b_desc.load([lora_id, expert_id, offs_bn, offs_bk + cur_k_offset])
|
||||
.reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
|
||||
.T
|
||||
)
|
||||
else:
|
||||
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
|
||||
# add (offs_bn < N) mask; optional .ca for B
|
||||
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
|
||||
if USE_B_L2_CACHE:
|
||||
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
|
||||
else:
|
||||
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
|
||||
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
|
||||
|
||||
if a_desc is not None:
|
||||
a = a_desc.load([offs_am, offs_ak + cur_k_offset])
|
||||
else:
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
|
||||
other=0.0,
|
||||
)
|
||||
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
|
||||
|
||||
if USE_GDC and not IS_PRIMARY:
|
||||
tl.extra.cuda.gdc_wait()
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
|
||||
other=0.0,
|
||||
)
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
|
||||
@@ -309,7 +387,19 @@ def _fused_moe_lora_kernel(
|
||||
accumulator = accumulator.to(c_ptr.dtype.element_ty)
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_ptrs = _get_c_ptrs(
|
||||
cur_c_ptr,
|
||||
lora_id,
|
||||
pid_m,
|
||||
offs,
|
||||
offs_token,
|
||||
offs_cn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
EM,
|
||||
BLOCK_SIZE_M,
|
||||
sort_c,
|
||||
)
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
|
||||
if SPLIT_K == 1:
|
||||
@@ -357,6 +447,7 @@ def _fused_moe_lora_shrink(
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
use_gdc: bool = False,
|
||||
use_tma: bool = False,
|
||||
) -> None:
|
||||
w1_lora_a_stacked = lora_a_stacked[0]
|
||||
shrink_config = {
|
||||
@@ -369,6 +460,7 @@ def _fused_moe_lora_shrink(
|
||||
"SPLIT_K": split_k,
|
||||
"USE_GDC": use_gdc,
|
||||
"launch_pdl": use_gdc, # triton kernel metadata
|
||||
"USE_TMA": use_tma,
|
||||
}
|
||||
|
||||
b_ptr = _get_ptr(lora_a_stacked, device)
|
||||
@@ -383,9 +475,20 @@ def _fused_moe_lora_shrink(
|
||||
len(lora_a_stacked),
|
||||
grid_lora_dim,
|
||||
)
|
||||
|
||||
a_desc = None
|
||||
b_desc = None
|
||||
if use_tma and num_slices == 1:
|
||||
b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
|
||||
lora_a_stacked[0],
|
||||
[1, 1, shrink_config["BLOCK_SIZE_N"], shrink_config["BLOCK_SIZE_K"]],
|
||||
)
|
||||
|
||||
_fused_moe_lora_kernel[grid](
|
||||
qcurr_hidden_states,
|
||||
a_desc,
|
||||
b_ptr,
|
||||
b_desc,
|
||||
a_intermediate_cache1,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
@@ -407,8 +510,8 @@ def _fused_moe_lora_shrink(
|
||||
w1_lora_a_stacked.stride(1),
|
||||
w1_lora_a_stacked.stride(3),
|
||||
w1_lora_a_stacked.stride(2),
|
||||
a_intermediate_cache1.stride(2),
|
||||
a_intermediate_cache1.stride(3),
|
||||
a_intermediate_cache1.stride(-2),
|
||||
a_intermediate_cache1.stride(-1),
|
||||
stride_tl,
|
||||
stride_el,
|
||||
slice_a_size=qcurr_hidden_states.numel(),
|
||||
@@ -419,7 +522,8 @@ def _fused_moe_lora_shrink(
|
||||
naive_block_assignment=sorted_token_ids is None,
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
ADD_INPUTS=False,
|
||||
USE_B_L2_CACHE=True, # new
|
||||
USE_B_L2_CACHE=True,
|
||||
sort_c=use_tma and sorted_token_ids is not None,
|
||||
IS_PRIMARY=True,
|
||||
**shrink_config,
|
||||
)
|
||||
@@ -462,6 +566,7 @@ def _fused_moe_lora_expand(
|
||||
mul_routed_weight: bool = False,
|
||||
offset: int = 0,
|
||||
use_gdc: bool = False,
|
||||
use_tma: bool = False,
|
||||
) -> None:
|
||||
b_ptr = _get_ptr(lora_b_stacked, device)
|
||||
K = max_lora_rank
|
||||
@@ -470,7 +575,7 @@ def _fused_moe_lora_expand(
|
||||
w1_lora_b_stacked = lora_b_stacked[0]
|
||||
|
||||
a_intermediate_cache1 = a_intermediate_cache1.view(
|
||||
-1, a_intermediate_cache1.shape[3]
|
||||
-1, a_intermediate_cache1.shape[-1]
|
||||
)
|
||||
|
||||
expand_config = {
|
||||
@@ -483,6 +588,7 @@ def _fused_moe_lora_expand(
|
||||
"SPLIT_K": 1, # Set split_k = 1 for expand calls
|
||||
"USE_GDC": use_gdc,
|
||||
"launch_pdl": use_gdc, # triton kernel metadata
|
||||
"USE_TMA": use_tma,
|
||||
}
|
||||
|
||||
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
|
||||
@@ -498,10 +604,27 @@ def _fused_moe_lora_expand(
|
||||
# Fast path: directly accumulate into the corresponding slice interval of output.
|
||||
out_view = output[:, :, offset : offset + num_slices * N]
|
||||
slice_c_size = N * out_view.stride(2)
|
||||
a_desc = None
|
||||
b_desc = None
|
||||
if use_tma:
|
||||
if sorted_token_ids is not None:
|
||||
a_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
|
||||
a_intermediate_cache1,
|
||||
[expand_config["BLOCK_SIZE_M"], expand_config["BLOCK_SIZE_K"]],
|
||||
)
|
||||
if num_slices == 1:
|
||||
b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
|
||||
lora_b_stacked[0],
|
||||
[1, 1, expand_config["BLOCK_SIZE_N"], expand_config["BLOCK_SIZE_K"]],
|
||||
)
|
||||
else:
|
||||
b_desc = None
|
||||
|
||||
_fused_moe_lora_kernel[grid](
|
||||
a_intermediate_cache1,
|
||||
a_desc,
|
||||
b_ptr,
|
||||
b_desc,
|
||||
out_view,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
@@ -535,7 +658,8 @@ def _fused_moe_lora_expand(
|
||||
naive_block_assignment=sorted_token_ids is None,
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
ADD_INPUTS=True,
|
||||
USE_B_L2_CACHE=True, # new
|
||||
USE_B_L2_CACHE=True,
|
||||
sort_c=False,
|
||||
IS_PRIMARY=False,
|
||||
**expand_config,
|
||||
)
|
||||
@@ -616,8 +740,34 @@ def _fused_moe_lora(
|
||||
else num_tokens * shrink_block_size_m
|
||||
)
|
||||
|
||||
# TMA is not currently compatiple with fully_sharded due to the non-determinism
|
||||
# of token id sorting across ranks.
|
||||
use_tma = supports_tma(device) and not fully_sharded
|
||||
|
||||
intermediate_cache_shape = (
|
||||
num_slices,
|
||||
M,
|
||||
top_k_num,
|
||||
max_lora_rank,
|
||||
)
|
||||
if use_tma:
|
||||
if num_slices > 1:
|
||||
# if num_slices > 1, we construct TMA descriptors for LoRA
|
||||
# weights within the kernel, which requires us to first set an allocator
|
||||
set_triton_allocator(device)
|
||||
|
||||
# When storing intermediate data in sorted order for TMA, we
|
||||
# need an extra 'num_active_loras' dim in the cache to avoid conflicts
|
||||
if sorted_token_ids is not None:
|
||||
intermediate_cache_shape = (
|
||||
num_slices,
|
||||
sorted_token_ids.shape[0],
|
||||
EM,
|
||||
max_lora_rank,
|
||||
)
|
||||
|
||||
a_intermediate_cache1 = torch.zeros(
|
||||
(num_slices, M, top_k_num, max_lora_rank),
|
||||
intermediate_cache_shape,
|
||||
dtype=output.dtype,
|
||||
device=device,
|
||||
)
|
||||
@@ -654,6 +804,7 @@ def _fused_moe_lora(
|
||||
num_active_loras,
|
||||
mul_routed_weight,
|
||||
use_gdc=use_gdc,
|
||||
use_tma=use_tma,
|
||||
)
|
||||
|
||||
if fully_sharded:
|
||||
@@ -703,6 +854,7 @@ def _fused_moe_lora(
|
||||
mul_routed_weight,
|
||||
offset,
|
||||
use_gdc=use_gdc,
|
||||
use_tma=use_tma,
|
||||
)
|
||||
|
||||
|
||||
@@ -772,6 +924,7 @@ def _fused_moe_lora_shrink_fake(
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
use_gdc: bool = False,
|
||||
use_tma: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@@ -809,6 +962,7 @@ def _fused_moe_lora_expand_fake(
|
||||
mul_routed_weight: bool = False,
|
||||
offset: int = 0,
|
||||
use_gdc: bool = False,
|
||||
use_tma: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
@@ -316,3 +316,9 @@ def supports_pdl(device: torch.device | None = None) -> bool:
|
||||
and current_platform.has_device_capability(90)
|
||||
and not envs.VLLM_LORA_DISABLE_PDL
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def supports_tma(device: torch.device | None = None) -> bool:
|
||||
# TMA requires compute capability SM90 or above
|
||||
return current_platform.is_cuda() and current_platform.has_device_capability(90)
|
||||
|
||||
13
vllm/triton_utils/allocation.py
Normal file
13
vllm/triton_utils/allocation.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def set_triton_allocator(device: torch.device):
|
||||
def alloc_fn(size: int, alignment: int, stream: int | None):
|
||||
return torch.empty(size, device=device, dtype=torch.int8)
|
||||
|
||||
triton.set_allocator(alloc_fn)
|
||||
Reference in New Issue
Block a user