diff --git a/tests/lora/test_punica_ops_fp8.py b/tests/lora/test_punica_ops_fp8.py new file mode 100644 index 000000000..042313336 --- /dev/null +++ b/tests/lora/test_punica_ops_fp8.py @@ -0,0 +1,999 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FP8 accuracy tests for LoRA shrink and expand kernels. + +Tests the FP8 kernels by: +1. Quantizing bf16 inputs/weights to FP8 +2. Dequantizing them back to bf16 +3. Running the bf16 reference (sgmv_shrink/sgmv_expand) with dequantized values +4. Comparing FP8 kernel output against this dequantized reference + +This isolates kernel correctness from quantization precision loss, +allowing much tighter tolerances than comparing against the original bf16. +""" + +import math +from threading import Lock + +import pytest +import torch + +import vllm.lora.ops.torch_ops as torch_ops +import vllm.lora.ops.triton_ops as triton_ops +from vllm.lora.ops.triton_ops import LoRAKernelMeta +from vllm.lora.ops.triton_ops.lora_expand_fp8_op import ( + _EXPAND_LORA_SCALE_PTR_DICT, +) +from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import ( + _SHRINK_LORA_SCALE_PTR_DICT, +) +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.utils.torch_utils import set_random_seed + +DEVICES = [f"cuda:{0}"] +SEED = [0] + +_dict_lock = Lock() + + +@pytest.fixture(autouse=True) +def reset_device(reset_default_device): + pass + + +# ============================================================================ +# Reference implementations (bf16 baseline) +# ============================================================================ + + +def sgmv_shrink_for_nslices( + nslices, + inputs_tensor, + lora_weights_lst, + out_tensor, + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + scaling, +): + """Wrapper around torch_ops.sgmv_shrink that handles any nslices.""" + for index in range(nslices): + torch_ops.sgmv_shrink( + inputs_tensor, + lora_weights_lst[index], + out_tensor[index], + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + scaling, + ) + + +def sgmv_expand_for_nslices( + nslices, + hidden_size, + inputs_tensor, + lora_weights_lst, + out_tensor, + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + add_inputs, +): + """Wrapper around torch_ops.sgmv_expand that handles any nslices.""" + if nslices == 1: + torch_ops.sgmv_expand( + inputs_tensor[0], + lora_weights_lst[0], + out_tensor, + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + add_inputs=add_inputs, + ) + else: + slice_offset = 0 + for index in range(nslices): + torch_ops.sgmv_expand_slice( + inputs_tensor[index], + lora_weights_lst[index], + out_tensor, + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + slice_offset, + hidden_size, + add_inputs=add_inputs, + ) + slice_offset += hidden_size + + +# ============================================================================ +# FP8 Quantization Helpers +# ============================================================================ + +FP8_DTYPE = torch.float8_e4m3fn +FP8_MAX = torch.finfo(FP8_DTYPE).max +FP8_MIN = torch.finfo(FP8_DTYPE).min + + +def quantize_to_fp8_per_tensor( + tensor: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a tensor to FP8 with per-tensor scaling.""" + amax = 
tensor.abs().float().max().clamp(min=1e-12) + scale = (amax / FP8_MAX).to(torch.float32) + fp8_tensor = (tensor.float() / scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE) + return fp8_tensor, scale.reshape(1) + + +def quantize_to_fp8_per_channel( + tensor: torch.Tensor, + channel_dim: int = 0, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a tensor to FP8 with per-channel scaling. + + For shrink lora_a weights of shape (num_loras, rank, hidden_size): + channel_dim=1 gives per-rank scaling -> scale shape (num_loras, rank) + For expand lora_b weights of shape (num_loras, hidden_size, rank): + channel_dim=1 gives per-hidden scaling -> scale shape (num_loras, hidden_size) + """ + # Compute amax along all dims except the leading dims up to channel_dim+1 + reduce_dims = list(range(channel_dim + 1, tensor.ndim)) + if reduce_dims: + amax = tensor.abs().float().amax(dim=reduce_dims).clamp(min=1e-12) + else: + amax = tensor.abs().float().clamp(min=1e-12) + scale = (amax / FP8_MAX).to(torch.float32) + + # Expand scale for broadcasting + for _ in reduce_dims: + scale = scale.unsqueeze(-1) + fp8_tensor = (tensor.float() / scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE) + scale = scale.squeeze() + if scale.ndim == 0: + scale = scale.unsqueeze(0) + return fp8_tensor, scale + + +def quantize_to_fp8_per_token( + tensor: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a 2D tensor to FP8 with per-token (per-row) scaling. + + Input shape: (num_tokens, hidden_size) + Returns: (fp8_tensor, scale) where scale shape is (num_tokens, 1) + """ + assert tensor.ndim == 2 + amax = tensor.abs().float().amax(dim=1, keepdim=True).clamp(min=1e-12) + scale = (amax / FP8_MAX).to(torch.float32) + fp8_tensor = (tensor.float() / scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE) + return fp8_tensor, scale + + +def quantize_to_fp8_blockwise( + tensor: torch.Tensor, + group_n: int, + group_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a 2D or 3D tensor to FP8 with block-wise scaling. 
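+
+    Each block is quantized symmetrically: scale = amax(block) / FP8_MAX,
+    and values are divided by that scale before casting to float8_e4m3fn.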
+ + For a 2D tensor (num_tokens, hidden_size): + Blocks of size (1, group_k) -> + scale shape (num_tokens, ceil(hidden_size/group_k)) + + For a 3D tensor (num_loras, N, K): + Blocks of size (group_n, group_k) -> + scale shape (num_loras, ceil(N/group_n), ceil(K/group_k)) + """ + if tensor.ndim == 2: + M, K = tensor.shape + n_blocks_k = math.ceil(K / group_k) + scale = torch.zeros(M, n_blocks_k, dtype=torch.float32, device=tensor.device) + fp8_tensor = torch.zeros_like(tensor, dtype=FP8_DTYPE) + for m in range(M): + for bk in range(n_blocks_k): + k_start = bk * group_k + k_end = min(k_start + group_k, K) + block = tensor[m, k_start:k_end].float() + amax = block.abs().max().clamp(min=1e-12) + s = (amax / FP8_MAX).to(torch.float32) + scale[m, bk] = s + fp8_tensor[m, k_start:k_end] = ( + (block / s).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE) + ) + return fp8_tensor, scale + elif tensor.ndim == 3: + L, N, K = tensor.shape + n_blocks_n = math.ceil(N / group_n) + n_blocks_k = math.ceil(K / group_k) + scale = torch.zeros( + L, n_blocks_n, n_blocks_k, dtype=torch.float32, device=tensor.device + ) + fp8_tensor = torch.zeros_like(tensor, dtype=FP8_DTYPE) + for li in range(L): + for bn in range(n_blocks_n): + for bk in range(n_blocks_k): + n_start = bn * group_n + n_end = min(n_start + group_n, N) + k_start = bk * group_k + k_end = min(k_start + group_k, K) + block = tensor[li, n_start:n_end, k_start:k_end].float() + amax = block.abs().max().clamp(min=1e-12) + s = (amax / FP8_MAX).to(torch.float32) + scale[li, bn, bk] = s + fp8_tensor[li, n_start:n_end, k_start:k_end] = ( + (block / s).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE) + ) + return fp8_tensor, scale + else: + raise ValueError(f"Unsupported tensor ndim: {tensor.ndim}") + + +# ============================================================================ +# FP8 Dequantization Helpers +# ============================================================================ + + +def dequantize_fp8_per_tensor( + fp8_tensor: torch.Tensor, + scale: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Dequantize FP8 tensor with per-tensor scale back to output_dtype.""" + return (fp8_tensor.float() * scale.float()).to(output_dtype) + + +def dequantize_fp8_per_channel( + fp8_tensor: torch.Tensor, + scale: torch.Tensor, + channel_dim: int, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Dequantize FP8 tensor with per-channel scale back to output_dtype. + + For 3D tensor (num_loras, N, K) with channel_dim=1: + scale shape is (num_loras, N), broadcast over K. + """ + expand_scale = scale.float() + # Add trailing dims for broadcasting + for _ in range(channel_dim + 1, fp8_tensor.ndim): + expand_scale = expand_scale.unsqueeze(-1) + return (fp8_tensor.float() * expand_scale).to(output_dtype) + + +def dequantize_fp8_per_token( + fp8_tensor: torch.Tensor, + scale: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Dequantize FP8 2D tensor with per-token scale back to output_dtype. 
+ + fp8_tensor: (num_tokens, hidden_size), scale: (num_tokens, 1) + """ + return (fp8_tensor.float() * scale.float()).to(output_dtype) + + +def dequantize_fp8_blockwise( + fp8_tensor: torch.Tensor, + scale: torch.Tensor, + group_n: int, + group_k: int, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Dequantize FP8 tensor with block-wise scale back to output_dtype.""" + if fp8_tensor.ndim == 2: + M, K = fp8_tensor.shape + out = torch.zeros(M, K, dtype=output_dtype, device=fp8_tensor.device) + n_blocks_k = math.ceil(K / group_k) + for m in range(M): + for bk in range(n_blocks_k): + k_start = bk * group_k + k_end = min(k_start + group_k, K) + out[m, k_start:k_end] = ( + fp8_tensor[m, k_start:k_end].float() * scale[m, bk].float() + ).to(output_dtype) + return out + elif fp8_tensor.ndim == 3: + L, N, K = fp8_tensor.shape + out = torch.zeros(L, N, K, dtype=output_dtype, device=fp8_tensor.device) + n_blocks_n = math.ceil(N / group_n) + n_blocks_k = math.ceil(K / group_k) + for l_idx in range(L): + for bn in range(n_blocks_n): + for bk in range(n_blocks_k): + n_start = bn * group_n + n_end = min(n_start + group_n, N) + k_start = bk * group_k + k_end = min(k_start + group_k, K) + out[l_idx, n_start:n_end, k_start:k_end] = ( + fp8_tensor[l_idx, n_start:n_end, k_start:k_end].float() + * scale[l_idx, bn, bk].float() + ).to(output_dtype) + return out + else: + raise ValueError(f"Unsupported tensor ndim: {fp8_tensor.ndim}") + + +# ============================================================================ +# FP8 Data Generation +# ============================================================================ + + +def generate_fp8_shrink_data( + batches: int, + hidden_size: int, + num_loras: int, + rank: int, + seq_length: int, + nslices: int, + dtype: torch.dtype, + device: str, + quant_mode: str, # "per_tensor", "per_channel", "blockwise" + group_k: int = 128, + group_n: int = 128, +): + """Generate test data for FP8 shrink kernel. + + Shrink: output = input @ lora_a^T * scaling + input: (num_tokens, hidden_size) -> quantized to FP8 + lora_a: (num_loras, rank, hidden_size) -> quantized to FP8 + + Returns bf16 reference tensors, FP8 quantized tensors with scales, + and dequantized bf16 tensors for accurate reference computation. 
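+
+    For example, with quant_mode="per_tensor", a_scale has shape (1,) and each
+    entry of b_scales has shape (num_loras,); with "per_channel", the shapes
+    are (total_tokens, 1) and (num_loras, rank); with "blockwise", they are
+    (total_tokens, ceil(hidden_size/group_k)) and
+    (num_loras, ceil(rank/group_n), ceil(hidden_size/group_k)).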
+ """ + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0, + ).to(device) + total_tokens = seq_len_tensor.sum().item() + + # Generate bf16 reference data + inputs_bf16 = torch.randn(total_tokens, hidden_size, dtype=dtype, device=device) + + lora_a_weights_bf16 = [] + for _ in range(nslices): + lora_a_weights_bf16.append( + torch.randn(num_loras, rank, hidden_size, dtype=dtype, device=device) + ) + + # Quantize inputs to FP8 and dequantize back for reference + if quant_mode == "blockwise": + inputs_fp8, a_scale = quantize_to_fp8_blockwise( + inputs_bf16, group_n=1, group_k=group_k + ) + inputs_dequant = dequantize_fp8_blockwise( + inputs_fp8, + a_scale, + group_n=1, + group_k=group_k, + output_dtype=dtype, + ) + elif quant_mode == "per_tensor": + # Per-tensor: kernel loads a single scalar from a_scale_ptr + inputs_fp8, a_scale = quantize_to_fp8_per_tensor(inputs_bf16) + inputs_dequant = dequantize_fp8_per_tensor( + inputs_fp8, + a_scale, + output_dtype=dtype, + ) + else: + # per_channel: kernel loads per-token a_scale via ram indexing + inputs_fp8, a_scale = quantize_to_fp8_per_token(inputs_bf16) + inputs_dequant = dequantize_fp8_per_token( + inputs_fp8, + a_scale, + output_dtype=dtype, + ) + + # Quantize lora_a weights to FP8 and dequantize back for reference + b_scales = [] + lora_a_weights_fp8 = [] + lora_a_weights_dequant = [] + for w in lora_a_weights_bf16: + if quant_mode == "per_tensor": + w_fp8, w_scale = quantize_to_fp8_per_tensor(w) + w_dequant = dequantize_fp8_per_tensor(w_fp8, w_scale, output_dtype=dtype) + # Scale shape: (1,) -> need (num_loras,) for the kernel + w_scale = w_scale.expand(num_loras).contiguous() + lora_a_weights_fp8.append(w_fp8) + b_scales.append(w_scale) + lora_a_weights_dequant.append(w_dequant) + elif quant_mode == "per_channel": + # Per-channel along rank dim: scale shape (num_loras, rank) + w_fp8, w_scale = quantize_to_fp8_per_channel(w, channel_dim=1) + w_dequant = dequantize_fp8_per_channel( + w_fp8, + w_scale, + channel_dim=1, + output_dtype=dtype, + ) + lora_a_weights_fp8.append(w_fp8) + b_scales.append(w_scale) + lora_a_weights_dequant.append(w_dequant) + elif quant_mode == "blockwise": + w_fp8, w_scale = quantize_to_fp8_blockwise( + w, group_n=group_n, group_k=group_k + ) + w_dequant = dequantize_fp8_blockwise( + w_fp8, + w_scale, + group_n=group_n, + group_k=group_k, + output_dtype=dtype, + ) + lora_a_weights_fp8.append(w_fp8) + b_scales.append(w_scale) + lora_a_weights_dequant.append(w_dequant) + + # Output tensor (float32 for shrink) + out_tensor = torch.zeros( + nslices, total_tokens, rank, dtype=torch.float32, device=device + ) + ref_out_tensor = out_tensor.clone() + + # Token-to-lora mapping + lora_indices_tensor = torch.randint(0, max(num_loras - 1, 1), (batches,)).to(device) + token_lora_mapping = torch.zeros(total_tokens, dtype=torch.long, device=device) + current_offset = 0 + for b_id in range(batches): + lora_index = lora_indices_tensor[b_id] + sl = seq_len_tensor[b_id].item() + token_lora_mapping[current_offset : current_offset + sl] = lora_index + current_offset += sl + + return { + "inputs_bf16": inputs_bf16, + "inputs_fp8": inputs_fp8, + "inputs_dequant": inputs_dequant, + "lora_a_bf16": lora_a_weights_bf16, + "lora_a_fp8": lora_a_weights_fp8, + "lora_a_dequant": lora_a_weights_dequant, + "a_scale": a_scale, + "b_scales": b_scales, + "out_tensor": out_tensor, + "ref_out_tensor": ref_out_tensor, 
+ "token_lora_mapping": token_lora_mapping, + "seq_len_tensor": seq_len_tensor, + "b_seq_start_loc": b_seq_start_loc, + "lora_indices_tensor": lora_indices_tensor, + "total_tokens": total_tokens, + } + + +def generate_fp8_expand_data( + batches: int, + hidden_size: int, + num_loras: int, + rank: int, + seq_length: int, + nslices: int, + dtype: torch.dtype, + device: str, + quant_mode: str, # "per_tensor", "per_channel", "blockwise" + group_k: int = 128, + group_n: int = 128, +): + """Generate test data for FP8 expand kernel (w8a8). + + Expand: output += input @ lora_b^T + input: (nslices, num_tokens, rank) -> quantized to FP8 (activations) + lora_b: (num_loras, hidden_size, rank) -> quantized to FP8 (weights) + + In w8a8 mode, both activations and weights are FP8. + Returns bf16 reference tensors, FP8 quantized tensors with scales, + and dequantized bf16 tensors for accurate reference computation. + """ + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0, + ).to(device) + total_tokens = seq_len_tensor.sum().item() + + # Generate bf16 input (shrink output) and quantize to FP8 + inputs_bf16 = torch.randn(nslices, total_tokens, rank, dtype=dtype, device=device) + + # Quantize input to FP8 and dequantize back for reference + inputs_2d_all = inputs_bf16.reshape(-1, rank) + if quant_mode == "blockwise": + # For blockwise, the kernel indexes a_scale by token id (0..total_tokens-1) + # shared across slices. Compute shared scale across slices, then quantize. + # First compute per-token-per-block scale across all slices + n_blocks_k = math.ceil(rank / group_k) + a_scale = torch.zeros( + total_tokens, n_blocks_k, dtype=torch.float32, device=device + ) + for m in range(total_tokens): + for bk in range(n_blocks_k): + k_start = bk * group_k + k_end = min(k_start + group_k, rank) + # Max across all slices for this token and block + block_amax = torch.tensor(0.0, device=device) + for s in range(nslices): + block = inputs_bf16[s, m, k_start:k_end].float() + block_amax = torch.max( + block_amax, block.abs().max().clamp(min=1e-12) + ) + a_scale[m, bk] = (block_amax / FP8_MAX).to(torch.float32) + + # Quantize all slices with the shared scale + inputs_fp8_list = [] + inputs_dequant_list = [] + for s in range(nslices): + slice_2d = inputs_bf16[s] # (total_tokens, rank) + fp8_slice = torch.zeros_like(slice_2d, dtype=FP8_DTYPE) + dequant_slice = torch.zeros_like(slice_2d) + for m in range(total_tokens): + for bk in range(n_blocks_k): + k_start = bk * group_k + k_end = min(k_start + group_k, rank) + block = slice_2d[m, k_start:k_end].float() + s_val = a_scale[m, bk] + fp8_slice[m, k_start:k_end] = ( + (block / s_val).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE) + ) + dequant_slice[m, k_start:k_end] = ( + fp8_slice[m, k_start:k_end].float() * s_val.float() + ).to(dtype) + inputs_fp8_list.append(fp8_slice) + inputs_dequant_list.append(dequant_slice) + inputs_fp8 = torch.stack(inputs_fp8_list, dim=0) + inputs_dequant = torch.stack(inputs_dequant_list, dim=0) + elif quant_mode == "per_tensor": + # Per-tensor: kernel loads a single scalar from a_scale_ptr + inputs_fp8_2d, a_scale = quantize_to_fp8_per_tensor(inputs_2d_all) + inputs_dequant_2d = dequantize_fp8_per_tensor( + inputs_fp8_2d, + a_scale, + output_dtype=dtype, + ) + inputs_fp8 = inputs_fp8_2d.reshape(nslices, total_tokens, rank) + inputs_dequant = inputs_dequant_2d.reshape(nslices, total_tokens, rank) + else: + # 
per_channel: kernel loads per-token a_scale via ram indexing. + # The kernel uses the same a_scale for all slices (indexed by token + # id 0..total_tokens-1), so we compute a shared per-token scale + # across all slices, then quantize each slice with that shared scale. + per_slice_views = [inputs_bf16[s] for s in range(nslices)] + # (nslices, total_tokens, rank) -> max across slices per token + stacked = torch.stack(per_slice_views, dim=0) # (nslices, tokens, rank) + amax = stacked.abs().float().amax(dim=(0, 2), keepdim=False).clamp(min=1e-12) + # amax shape: (total_tokens,) + a_scale = (amax / FP8_MAX).to(torch.float32).unsqueeze(1) # (tokens, 1) + # Quantize all slices with the shared scale + inputs_fp8_2d = ( + (inputs_2d_all.float() / a_scale.repeat(nslices, 1)) + .clamp(FP8_MIN, FP8_MAX) + .to(FP8_DTYPE) + ) + inputs_dequant_2d = ( + inputs_fp8_2d.float() * a_scale.repeat(nslices, 1).float() + ).to(dtype) + inputs_fp8 = inputs_fp8_2d.reshape(nslices, total_tokens, rank) + inputs_dequant = inputs_dequant_2d.reshape(nslices, total_tokens, rank) + + # Generate bf16 LoRA B weights + lora_b_weights_bf16 = [] + for _ in range(nslices): + lora_b_weights_bf16.append( + torch.randn(num_loras, hidden_size, rank, dtype=dtype, device=device) + ) + + # Quantize LoRA B weights to FP8 and dequantize back for reference + b_scales = [] + lora_b_weights_fp8 = [] + lora_b_weights_dequant = [] + for w in lora_b_weights_bf16: + if quant_mode == "per_tensor": + w_fp8, w_scale = quantize_to_fp8_per_tensor(w) + w_dequant = dequantize_fp8_per_tensor(w_fp8, w_scale, output_dtype=dtype) + w_scale = w_scale.expand(num_loras).contiguous() + lora_b_weights_fp8.append(w_fp8) + b_scales.append(w_scale) + lora_b_weights_dequant.append(w_dequant) + elif quant_mode == "per_channel": + # Per-channel along hidden_size dim: scale (num_loras, hidden_size) + w_fp8, w_scale = quantize_to_fp8_per_channel(w, channel_dim=1) + w_dequant = dequantize_fp8_per_channel( + w_fp8, + w_scale, + channel_dim=1, + output_dtype=dtype, + ) + lora_b_weights_fp8.append(w_fp8) + b_scales.append(w_scale) + lora_b_weights_dequant.append(w_dequant) + elif quant_mode == "blockwise": + w_fp8, w_scale = quantize_to_fp8_blockwise( + w, group_n=group_n, group_k=group_k + ) + w_dequant = dequantize_fp8_blockwise( + w_fp8, + w_scale, + group_n=group_n, + group_k=group_k, + output_dtype=dtype, + ) + lora_b_weights_fp8.append(w_fp8) + b_scales.append(w_scale) + lora_b_weights_dequant.append(w_dequant) + + # Output tensor (initialized randomly for add_inputs) + out_tensor = torch.randn( + total_tokens, hidden_size * nslices, dtype=dtype, device=device + ) + ref_out_tensor = out_tensor.clone() + + # Token-to-lora mapping + lora_indices_tensor = torch.randint(0, max(num_loras - 1, 1), (batches,)).to(device) + token_lora_mapping = torch.zeros(total_tokens, dtype=torch.long, device=device) + current_offset = 0 + for b_id in range(batches): + lora_index = lora_indices_tensor[b_id] + sl = seq_len_tensor[b_id].item() + token_lora_mapping[current_offset : current_offset + sl] = lora_index + current_offset += sl + + return { + "inputs_bf16": inputs_bf16, + "inputs_fp8": inputs_fp8, + "inputs_dequant": inputs_dequant, + "a_scale": a_scale, + "lora_b_bf16": lora_b_weights_bf16, + "lora_b_fp8": lora_b_weights_fp8, + "lora_b_dequant": lora_b_weights_dequant, + "b_scales": b_scales, + "out_tensor": out_tensor, + "ref_out_tensor": ref_out_tensor, + "token_lora_mapping": token_lora_mapping, + "seq_len_tensor": seq_len_tensor, + "b_seq_start_loc": b_seq_start_loc, + 
"lora_indices_tensor": lora_indices_tensor, + "total_tokens": total_tokens, + } + + +# ============================================================================ +# FP8 Shrink Kernel Check +# ============================================================================ + + +def check_lora_shrink_fp8_kernel( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seq_length: int, + scaling: float, + quant_mode: str, + group_k: int = 128, + group_n: int = 128, +): + """Test FP8 shrink kernel against dequantized bf16 reference. + + Instead of comparing FP8 kernel output against the original bf16 reference + (which conflates quantization error with kernel error), we: + 1. Quantize bf16 inputs/weights to FP8 + 2. Dequantize them back to bf16 + 3. Run the bf16 reference (sgmv_shrink) with the dequantized values + 4. Compare FP8 kernel output against this dequantized reference + + This isolates kernel correctness from quantization precision loss, + allowing much tighter tolerances. + """ + data = generate_fp8_shrink_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + nslices, + dtype, + device, + quant_mode, + group_k, + group_n, + ) + + total_tokens = data["total_tokens"] + + # Setup LoRA kernel metadata + lora_meta = LoRAKernelMeta.make( + max_loras=num_loras, max_num_tokens=total_tokens, device=device + ) + lora_meta.prepare_tensors(data["token_lora_mapping"]) + + out_tensor = data["out_tensor"] + + # Determine quantization params for the kernel + per_channel = quant_mode == "per_channel" + gk = group_k if quant_mode == "blockwise" else 0 + gn = group_n if quant_mode == "blockwise" else 0 + + with _dict_lock: + _LORA_A_PTR_DICT.clear() + _SHRINK_LORA_SCALE_PTR_DICT.clear() + triton_ops.lora_shrink_fp8( + data["inputs_fp8"], + data["lora_a_fp8"], + out_tensor, + *lora_meta.meta_args(token_nums=total_tokens, specialize_active_lora=False), + scaling, + data["b_scales"], + a_scale=data["a_scale"], + group_k=gk, + group_n=gn, + use_fp8_w8a8=True, + per_channel_quant=per_channel, + ) + + # Compute reference using dequantized (round-tripped) tensors. + # This means the reference sees the same quantization error as the kernel, + # so any difference is purely kernel error. + ref_out_tensor = data["ref_out_tensor"] + max_seq_length = data["seq_len_tensor"].max().item() + sgmv_shrink_for_nslices( + nslices, + data["inputs_dequant"], + data["lora_a_dequant"], + ref_out_tensor, + data["b_seq_start_loc"], + data["seq_len_tensor"], + data["lora_indices_tensor"], + batches, + max_seq_length, + total_tokens, + scaling, + ) + + # With dequantized reference, we can use much tighter tolerances + # since we're only measuring kernel error, not quantization error. + # Blockwise accumulation order differs from the bf16 reference, so + # allow a slightly larger margin for sporadic rounding outliers. 
+ rtol, atol = 0.1, 0.25 + torch.testing.assert_close( + out_tensor.to(dtype), ref_out_tensor.to(dtype), rtol=rtol, atol=atol + ) + + +# ============================================================================ +# FP8 Expand Kernel Check +# ============================================================================ + + +def check_lora_expand_fp8_kernel( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seq_length: int, + add_inputs: bool, + quant_mode: str, + group_k: int = 128, + group_n: int = 128, +): + """Test FP8 expand kernel (w8a8) against dequantized bf16 reference. + + Instead of comparing FP8 kernel output against the original bf16 reference + (which conflates quantization error with kernel error), we: + 1. Quantize bf16 inputs/weights to FP8 + 2. Dequantize them back to bf16 + 3. Run the bf16 reference (sgmv_expand) with the dequantized values + 4. Compare FP8 kernel output against this dequantized reference + + This isolates kernel correctness from quantization precision loss, + allowing much tighter tolerances. + """ + data = generate_fp8_expand_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + nslices, + dtype, + device, + quant_mode, + group_k, + group_n, + ) + + total_tokens = data["total_tokens"] + + # Setup LoRA kernel metadata + lora_meta = LoRAKernelMeta.make( + max_loras=num_loras, max_num_tokens=total_tokens, device=device + ) + lora_meta.prepare_tensors(data["token_lora_mapping"]) + + out_tensor = data["out_tensor"] + + # Determine quantization params for the kernel + per_channel = quant_mode == "per_channel" + gk = group_k if quant_mode == "blockwise" else 0 + gn = group_n if quant_mode == "blockwise" else 0 + + with _dict_lock: + _LORA_B_PTR_DICT.clear() + _EXPAND_LORA_SCALE_PTR_DICT.clear() + triton_ops.lora_expand_fp8( + data["inputs_fp8"], + data["lora_b_fp8"], + out_tensor, + *lora_meta.meta_args(token_nums=total_tokens, specialize_active_lora=False), + data["b_scales"], + a_scale=data["a_scale"], + offset_start=0, + add_inputs=add_inputs, + group_k=gk, + group_n=gn, + use_fp8_w8a8=True, + per_channel_quant=per_channel, + ) + + # Compute reference using dequantized (round-tripped) tensors. + ref_out_tensor = data["ref_out_tensor"] + max_seq_length = data["seq_len_tensor"].max().item() + sgmv_expand_for_nslices( + nslices, + hidden_size, + data["inputs_dequant"], + data["lora_b_dequant"], + ref_out_tensor, + data["b_seq_start_loc"], + data["seq_len_tensor"], + data["lora_indices_tensor"], + batches, + max_seq_length, + total_tokens, + add_inputs=add_inputs, + ) + + # With dequantized reference, we can use much tighter tolerances + # since we're only measuring kernel error, not quantization error. 
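+    # The expand kernel accumulates in fp32 and writes `dtype` outputs
+    # directly, so a slightly tighter atol than the shrink test suffices.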
+ rtol, atol = 0.1, 0.15 + torch.testing.assert_close(out_tensor, ref_out_tensor, rtol=rtol, atol=atol) + + +# ============================================================================ +# FP8 Test Parameters +# ============================================================================ + +fp8_test_params = { + "hidden_sizes": [512, 1024, 2048], + "batches": [1, 4, 16], + "num_loras": [1, 4, 8], + "max_ranks": [8, 16, 32, 64], +} + + +# ============================================================================ +# FP8 Shrink Tests +# ============================================================================ + + +@pytest.mark.parametrize("batches", fp8_test_params["batches"]) +@pytest.mark.parametrize("num_loras", fp8_test_params["num_loras"]) +@pytest.mark.parametrize("rank", fp8_test_params["max_ranks"]) +@pytest.mark.parametrize("hidden_size", fp8_test_params["hidden_sizes"]) +@pytest.mark.parametrize("nslices", [1, 2, 3]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("quant_mode", ["per_tensor", "per_channel", "blockwise"]) +def test_lora_shrink_fp8( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seed: int, + quant_mode: str, +): + """Test FP8 shrink kernel with per-tensor, per-channel, and block-wise + quantization, comparing against the bf16 baseline.""" + torch.set_default_device(device) + set_random_seed(seed) + + # For blockwise, group sizes must divide evenly or be handled by the kernel + group_k = 128 + group_n = 128 + + # Adjust group sizes if they're larger than the dimensions + if quant_mode == "blockwise": + group_k = min(group_k, hidden_size) + group_n = min(group_n, rank) + + check_lora_shrink_fp8_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5, + quant_mode=quant_mode, + group_k=group_k, + group_n=group_n, + ) + + +# ============================================================================ +# FP8 Expand Tests +# ============================================================================ + + +@pytest.mark.parametrize("batches", fp8_test_params["batches"]) +@pytest.mark.parametrize("num_loras", fp8_test_params["num_loras"]) +@pytest.mark.parametrize("rank", fp8_test_params["max_ranks"]) +@pytest.mark.parametrize("hidden_size", fp8_test_params["hidden_sizes"]) +@pytest.mark.parametrize("nslices", [1, 2, 3]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("quant_mode", ["per_tensor", "per_channel", "blockwise"]) +def test_lora_expand_fp8( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seed: int, + quant_mode: str, +): + """Test FP8 expand kernel with per-tensor, per-channel, and block-wise + quantization, comparing against the bf16 baseline.""" + torch.set_default_device(device) + set_random_seed(seed) + + group_k = 128 + group_n = 128 + + # Adjust group sizes if they're larger than the dimensions + if quant_mode == "blockwise": + group_k = min(group_k, rank) + group_n = min(group_n, hidden_size) + + check_lora_expand_fp8_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + 
seq_length=128, + add_inputs=True, + quant_mode=quant_mode, + group_k=group_k, + group_n=group_n, + ) diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index 76587376a..687170b30 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -12,13 +12,17 @@ from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( fused_moe_lora_expand, fused_moe_lora_shrink, ) +from vllm.lora.ops.triton_ops.lora_expand_fp8_op import lora_expand_fp8 from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta +from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import lora_shrink_fp8 from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink __all__ = [ "lora_expand", + "lora_expand_fp8", "lora_shrink", + "lora_shrink_fp8", "LoRAKernelMeta", "fused_moe_lora", "fused_moe_lora_shrink", diff --git a/vllm/lora/ops/triton_ops/fp8_kernel_utils.py b/vllm/lora/ops/triton_ops/fp8_kernel_utils.py new file mode 100644 index 000000000..8429562c7 --- /dev/null +++ b/vllm/lora/ops/triton_ops/fp8_kernel_utils.py @@ -0,0 +1,603 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Utilities for Punica kernel construction. +""" + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _accumulate_mm( + tiled_a, + tiled_b, + accumulator, + a_scale_ptr, + b_scale_ptr, + a_scale_k_stride, + b_scale_k_stride, + iter_k, + group_k: tl.constexpr, + group_n: tl.constexpr, + use_fp8_w8a8: tl.constexpr, +): + """ + Core matrix multiplication and accumulation logic with quantization support. + + Args: + tiled_a (tl.tensor): Loaded tile from A matrix + tiled_b (tl.tensor): Loaded tile from B matrix + accumulator (tl.tensor): Current accumulator value + a_scale_ptr (tl.tensor): Scale pointer for A matrix + b_scale_ptr (tl.tensor): Scale pointer for B matrix + a_scale_k_stride (int): K dimension stride for A's block-wise scales + b_scale_k_stride (int): K dimension stride for B's block-wise scales + iter_k (int): Current iteration's global K offset + group_k: Block size for K dimension in block-wise quantization + group_n: Block size for N dimension in block-wise quantization + use_fp8_w8a8: Whether using FP8 W8A8 quantization + """ + + if use_fp8_w8a8: + if group_k > 0 and group_n > 0: + # Block-wise quantization: scales are loaded per block + offs_ks = iter_k // group_k + # a_scale_ptr is (BLOCK_M,) tensor of base pointers per row + # Load scale for current K-group, result shape: (BLOCK_M,) + a_scale = tl.load(a_scale_ptr + offs_ks * a_scale_k_stride) + # b_scale_ptr is (BLOCK_N,) tensor with N-offset pre-baked + # Load scale for current K-group, result shape: (BLOCK_N,) + b_scale = tl.load(b_scale_ptr + offs_ks * b_scale_k_stride) + accumulator += ( + tl.dot(tiled_a, tiled_b) * a_scale[:, None] * b_scale[None, :] + ) + else: + # Tensor-wise or per-channel: accumulate and scale at end + accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator) + else: + accumulator += tl.dot(tiled_a, tiled_b) + return accumulator + + +@triton.jit +def fp8_mm_k( + a_ptr, + b_ptr, + a_scale_ptr, + b_scale_ptr, + ak_stride, + bk_stride, + a_scale_k_stride, + b_scale_k_stride, + offset_k, + K: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + group_k: tl.constexpr, + group_n: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + per_channel_quant: 
tl.constexpr, + CAST_TYPE: tl.constexpr, + b_dtype: tl.constexpr, + USE_GDC: tl.constexpr, + base_k, +): + """ + FP8-compatible matrix multiplication kernel with quantization support. + Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of + B (k x n), iterate through the K dimension to compute the partial/complete + matrix block product with proper dequantization. + + Args: + a_ptr (tl.tensor): Array of pointers, identifying rows of A + (FP8 or other dtype) + b_ptr (tl.tensor): Array of pointers, identifying columns of B + (FP8 dtype) + a_scale_ptr (tl.tensor): Scale pointer for A matrix + (per-token or block-wise) + b_scale_ptr (tl.tensor): Scale pointer for B matrix + (per-channel or block-wise) + ak_stride (int): K dimension stride of the A matrix + bk_stride (int): K dimension stride of the B matrix + a_scale_k_stride (int): K dimension stride for A's block-wise scales + b_scale_k_stride (int): K dimension stride for B's block-wise scales + offset_k (int): Base offset along K dimension + K: Length of the K dimension + BLOCK_M: M dimension of the output block m x n + BLOCK_N: N dimension of the output block m x n + BLOCK_K: K dimension atom + EVEN_K: True if the blocks of A and B can be loaded without masking + SPLIT_K: Parameter signifying parallelism in the K dimension + group_k: Block size for K dimension in block-wise quantization + group_n: Block size for N dimension in block-wise quantization + use_fp8_w8a8: Whether using FP8 W8A8 quantization + per_channel_quant: Whether using per-channel quantization + CAST_TYPE: if True, cast the values from the A matrix to the B + matrix dtype. + b_dtype: datatype of the B matrix + USE_GDC: Whether to use PDL. True indicates use. + base_k (int): Base offset along K dimension for current SPLIT_K group + """ + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # Step size along K for each iteration + STEP_K = BLOCK_K * SPLIT_K + + # Total number of iterations (compile-time constant) + num_iters = tl.cdiv(K, STEP_K) + + for k in range(num_iters): + # Current iteration's global K offset + iter_k = k * STEP_K + base_k + block_end = iter_k + BLOCK_K + + # Skip iterations that are entirely past the K boundary + if not EVEN_K and iter_k >= K: + pass + elif EVEN_K or block_end <= K: + # No masking needed: either K is evenly divisible (EVEN_K) + # or this block fits entirely within K + tiled_b = tl.load(b_ptr) + if USE_GDC: + tl.extra.cuda.gdc_wait() + tiled_a = tl.load(a_ptr) + if CAST_TYPE: + tiled_a = tiled_a.to(b_dtype) + + accumulator = _accumulate_mm( + tiled_a, + tiled_b, + accumulator, + a_scale_ptr, + b_scale_ptr, + a_scale_k_stride, + b_scale_k_stride, + iter_k, + group_k, + group_n, + use_fp8_w8a8, + ) + else: + # Partial block at the tail: mask out-of-bounds elements + k_offsets = tl.arange(0, BLOCK_K) + mask = iter_k + k_offsets < K + tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0) + if USE_GDC: + tl.extra.cuda.gdc_wait() + tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0) + if CAST_TYPE: + tiled_a = tiled_a.to(b_dtype) + + accumulator = _accumulate_mm( + tiled_a, + tiled_b, + accumulator, + a_scale_ptr, + b_scale_ptr, + a_scale_k_stride, + b_scale_k_stride, + iter_k, + group_k, + group_n, + use_fp8_w8a8, + ) + + a_ptr += STEP_K * ak_stride + b_ptr += STEP_K * bk_stride + + return accumulator + + +@triton.jit +def do_shrink_kernel_fp8( + pid_n, + pid_sk, + slice_id, + lora_index, + input_ptr, + lora_ptr, + out_ptr, + a_scale_ptr, + b_scale_ptr, + N, + K, + M_LEN, + ram, + # input strides + 
input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # scale strides + a_scale_m_stride, + a_scale_k_stride, + b_scale_l_stride, + b_scale_n_stride, + b_scale_k_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + # block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr, + USE_GDC: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + per_channel_quant: tl.constexpr, + launch_pdl: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, compute the + matrix product and store in the appropriate output location. + """ + + # Identify the lora_ptr from slice_id. + if SLICE_NUM == 1: + cur_lora_ptr = lora_ptr + cur_b_scale_ptr = b_scale_ptr + else: + cur_lora_ptr = ( + tl.load(lora_ptr + slice_id).to(tl.pointer_type(tl.float8e4nv)) + if b_scale_ptr is not None + else tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty) + ) + ) + cur_b_scale_ptr = ( + tl.load(b_scale_ptr + slice_id).to(tl.pointer_type(tl.float32)) + if b_scale_ptr is not None + else b_scale_ptr + ) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + a_ptr = ( + input_ptr + ram[:, None] * input_d0_stride + offset_k[None, :] * input_d1_stride + ) + b_ptr = ( + cur_lora_ptr + + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride + ) + + # Load scales for tensor-wise or per-channel quantization (outside the loop) + # Block-wise scales are loaded inside fp8_mm_k + if use_fp8_w8a8: + if group_k > 0 and group_n > 0: + # Block-wise: compute scale pointers for fp8_mm_k + # a_scale: per-row base pointers, shape (BLOCK_M,) + # Each pointer points to the start of that row's scale data + mm_a_scale_ptr = a_scale_ptr + ram * a_scale_m_stride + + # b_scale: pre-compute N-dimension offset + # We need to bake in the N-group offset since fp8_mm_k doesn't know pid_n + n_offset = pid_n * BLOCK_N + offs_ns = (n_offset + tl.arange(0, BLOCK_N)) // group_n + # Base pointer with lora offset + N-group offset baked in, shape (BLOCK_N,) + mm_b_scale_ptr = ( + cur_b_scale_ptr + + lora_index * b_scale_l_stride + + offs_ns * b_scale_n_stride + ) + elif per_channel_quant: + # Per-channel for weights, per-token for activations + b_scale_ptrs = ( + cur_b_scale_ptr + lora_index * b_scale_l_stride + rbn * b_scale_n_stride + ) + b_scale = tl.load(b_scale_ptrs) + # Per-token activation scale + a_scale = tl.load(a_scale_ptr + ram * a_scale_m_stride)[:, None] + # For non-block-wise, pass original pointers (not used in mm loop) + mm_a_scale_ptr = a_scale_ptr + mm_b_scale_ptr = cur_b_scale_ptr + else: + # Tensor-wise quantization + a_scale = tl.load(a_scale_ptr) if a_scale_ptr is not None else 1.0 + b_scale = tl.load(cur_b_scale_ptr + lora_index * b_scale_l_stride) + # For non-block-wise, pass original pointers (not used in mm loop) + mm_a_scale_ptr = a_scale_ptr + mm_b_scale_ptr = cur_b_scale_ptr + else: + # Non-quantized path + mm_a_scale_ptr = 
a_scale_ptr + mm_b_scale_ptr = cur_b_scale_ptr + + # Compute partial/complete block matrix product. + accumulator = fp8_mm_k( + a_ptr, + b_ptr, + mm_a_scale_ptr, + mm_b_scale_ptr, + input_d1_stride, + lora_d2_stride, + a_scale_k_stride, + b_scale_k_stride, + offset_k, + K, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + group_k, + group_n, + use_fp8_w8a8, + per_channel_quant, + False, + cur_lora_ptr.dtype.element_ty, + USE_GDC, + base_k=pid_sk * BLOCK_K, + ) + # GDC launch dependents hints the runtime system to launch dependent kernels. + if USE_GDC: + tl.extra.cuda.gdc_launch_dependents() + + # Apply dequantization scales for tensor-wise/per-channel quantization + if use_fp8_w8a8: + if group_k > 0 and group_n > 0: + # Block-wise: already applied in fp8_mm_k + pass + else: + # Tensor-wise or per-channel: apply scales after accumulation + accumulator = accumulator * a_scale * b_scale + + # Apply LoRA scaling factor + accumulator *= scaling + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_cm = tl.arange(0, BLOCK_M) + cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * output_d0_stride + c_ptr = ( + cur_out_ptr + + ram[:, None] * output_d1_stride + + offset_cn[None, :] * output_d2_stride + ) + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N) + + # Cast accumulator to output dtype + accumulator = accumulator.to(out_ptr.dtype.element_ty) + + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask, sem="relaxed") + + +@triton.jit +def do_expand_kernel_fp8( + pid_n, + lora_index, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + a_scale_ptr, + b_scale_ptr, + N, + K, + M_LEN, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # scale strides + a_scale_m_stride, + a_scale_k_stride, + b_scale_l_stride, + b_scale_n_stride, + b_scale_k_stride, + # out ptr strides + output_d0_stride, + output_d1_stride, + # block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # constants + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SAME_STRIDE: tl.constexpr, + SLICE_NUM: tl.constexpr, + EVEN_K: tl.constexpr, + CAST_TYPE: tl.constexpr, + ADD_INPUTS: tl.constexpr, + USE_GDC: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + per_channel_quant: tl.constexpr, +): + """ + FP8-compatible expand kernel for LoRA. + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, + compute the matrix product with FP8 quantization support and store in + the appropriate output location. + + For expand kernel, the input (shrink output) may be in FP32/FP16/BF16, + while the LoRA B weights can be in FP8. 
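+
+    In the w8a8 path both operands are FP8 and the dot product is
+    dequantized with the a/b scales: per block inside fp8_mm_k for
+    block-wise mode, or once after the K loop for tensor-wise and
+    per-channel modes.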
+ + Supports: + - FP8 W8A8 quantization for LoRA B weights + - Block-wise quantization with configurable group_k and group_n + - Per-channel quantization + - Tensor-wise quantization + """ + + # ls_d*_ptr can be either an integer or a pointer + if SAME_STRIDE: + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + + # Identify the input_ptr and lora_ptr from slice_id. + if SLICE_NUM == 1: + cur_input_ptr = input_ptr + if use_fp8_w8a8: + cur_lora_ptr = lora_ptr + cur_b_scale_ptr = b_scale_ptr + else: + cur_lora_ptr = lora_ptr + cur_b_scale_ptr = b_scale_ptr # May be None for non-quantized + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + if use_fp8_w8a8: + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(tl.float8e4nv) + ) + cur_b_scale_ptr = tl.load(b_scale_ptr + slice_id).to( + tl.pointer_type(tl.float32) + ) + else: + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty) + ) + cur_b_scale_ptr = ( + tl.load(b_scale_ptr + slice_id).to(tl.pointer_type(tl.float32)) + if b_scale_ptr is not None + else None + ) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = tl.arange(0, BLOCK_K) + a_ptr = ( + cur_input_ptr + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride + ) + b_ptr = ( + cur_lora_ptr + + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride + ) + + # Setup scale pointers for FP8/INT8 quantization + if use_fp8_w8a8: + if group_k > 0 and group_n > 0: + # Block-wise quantization - compute scale pointers for fp8_mm_k + # a_scale: per-row base pointers, shape (BLOCK_M,) + mm_a_scale_ptr = a_scale_ptr + ram * a_scale_m_stride + + # b_scale: pre-compute N-dimension offset since fp8_mm_k doesn't know pid_n + n_offset = pid_n * BLOCK_N + offs_ns = (n_offset + tl.arange(0, BLOCK_N)) // group_n + # Base pointer with lora offset + N-group offset baked in, shape (BLOCK_N,) + mm_b_scale_ptr = ( + cur_b_scale_ptr + + lora_index * b_scale_l_stride + + offs_ns * b_scale_n_stride + ) + elif per_channel_quant: + # Per-channel for weights, shape (BLOCK_N,) + b_scale_ptrs = ( + cur_b_scale_ptr + lora_index * b_scale_l_stride + rbn * b_scale_n_stride + ) + b_scale = tl.load(b_scale_ptrs) + # Per-token activation scale, only if a_scale_ptr provided + a_scale = tl.load(a_scale_ptr + ram * a_scale_m_stride)[:, None] + # For non-block-wise, pass original pointers (not used in mm loop) + mm_a_scale_ptr = a_scale_ptr + mm_b_scale_ptr = cur_b_scale_ptr + else: + # Tensor-wise quantization + a_scale = tl.load(a_scale_ptr) if a_scale_ptr is not None else 1.0 + b_scale = tl.load(cur_b_scale_ptr + lora_index * b_scale_l_stride) + # For non-block-wise, pass original pointers (not used in mm loop) + mm_a_scale_ptr = a_scale_ptr + mm_b_scale_ptr = cur_b_scale_ptr + else: + # Non-quantized path + mm_a_scale_ptr = a_scale_ptr + mm_b_scale_ptr = cur_b_scale_ptr + + # Compute the block matrix product using fp8_mm_k + # Note: For expand kernel, SPLIT_K=1, so we pass 1 for SPLIT_K + accumulator = fp8_mm_k( + a_ptr, + b_ptr, + mm_a_scale_ptr, + mm_b_scale_ptr, + input_d2_stride, # ak_stride + 
cur_lora_d2_stride,  # bk_stride
+        a_scale_k_stride,
+        b_scale_k_stride,
+        offset_k,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        EVEN_K,
+        1,  # SPLIT_K = 1 for expand kernel
+        group_k,
+        group_n,
+        use_fp8_w8a8,
+        per_channel_quant,
+        CAST_TYPE,  # CAST_TYPE - cast the A (input) tiles to B's dtype
+        cur_lora_ptr.dtype.element_ty,
+        USE_GDC,
+        base_k=0,
+    )
+
+    # Apply dequantization scales for non-block-wise quantization
+    if use_fp8_w8a8:
+        if group_k > 0 and group_n > 0:
+            pass  # Already applied per block in fp8_mm_k
+        else:
+            # Tensor-wise or per-channel: apply scales after accumulation
+            accumulator = accumulator * a_scale * b_scale
+
+    tiled_c = accumulator.to(out_ptr.dtype.element_ty)
+    if SLICE_NUM == 1:
+        cur_slice_start = slice_start_loc
+    else:
+        cur_slice_start = tl.load(slice_start_loc + slice_id)
+
+    # Identify the C output pointers to store the results of the accumulator.
+    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
+    offset_cm = tl.arange(0, BLOCK_M)
+    c_ptr = (
+        out_ptr
+        + ram[:, None] * output_d0_stride
+        + offset_cn[None, :] * output_d1_stride
+    )
+    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < (cur_slice_start + N))
+
+    if ADD_INPUTS:
+        tiled_out = tl.load(c_ptr, mask=c_mask)
+        tiled_c += tiled_out
+    tl.store(c_ptr, tiled_c, mask=c_mask)
diff --git a/vllm/lora/ops/triton_ops/lora_expand_fp8_op.py b/vllm/lora/ops/triton_ops/lora_expand_fp8_op.py
new file mode 100644
index 000000000..d5850f118
--- /dev/null
+++ b/vllm/lora/ops/triton_ops/lora_expand_fp8_op.py
@@ -0,0 +1,403 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Based on:
+Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
+Punica: Multi-Tenant LoRA Serving.
+https://arxiv.org/abs/2310.18547
+"""
+
+import torch
+
+from vllm.lora.ops.triton_ops.fp8_kernel_utils import do_expand_kernel_fp8
+from vllm.lora.ops.triton_ops.utils import (
+    _get_lora_b_ptr,
+    get_lora_op_configs,
+)
+from vllm.triton_utils import tl, triton
+from vllm.utils.torch_utils import direct_register_custom_op
+
+_EXPAND_LORA_SCALE_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {}
+
+
+def _get_expand_lora_scale_ptr(lora_weights: list[torch.Tensor], device: torch.device):
+    """
+    `_EXPAND_LORA_SCALE_PTR_DICT` collects the required information during
+    `profile_run`. After this, it remains constant and subsequent usage is
+    through LUT.
+    Refer to:
+    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
+    """
+    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
+
+    if (ptr_tensor := _EXPAND_LORA_SCALE_PTR_DICT.get(key)) is not None:
+        return ptr_tensor
+
+    if len(lora_weights) > 1:
+        tensor_ptrs = []
+        for lora_weight in lora_weights:
+            tensor_ptrs.append(lora_weight.data_ptr())
+        ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
+    else:
+        # Single slice: return the actual tensor so the kernel can use it
+        # directly without pointer indirection (matches SLICE_NUM == 1 path).
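+        # The multi-slice path above instead packs raw data_ptr() values
+        # into a uint64 tensor that the kernel dereferences per slice.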
+ ptr_tensor = lora_weights[0] + + _EXPAND_LORA_SCALE_PTR_DICT[key] = ptr_tensor + return _EXPAND_LORA_SCALE_PTR_DICT.get(key) + + +@triton.jit +def _lora_expand_kernel_fp8( + input_ptr, + lora_ptr, + out_ptr, + a_scale_ptr, + b_scale_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + a_scale_m_stride, + a_scale_k_stride, + b_scale_l_stride, + b_scale_n_stride, + b_scale_k_stride, + output_d0_stride, + output_d1_stride, + output_hs_ptr, + group_n: tl.constexpr, + group_k: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr, + USE_GDC: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + per_channel_quant: tl.constexpr, + launch_pdl: tl.constexpr, +): + """ + FP8-compatible expand kernel wrapper. + """ + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_mn = tl.program_id(axis=0) + pid_m = pid_mn % cta_m_num + pid_n = (pid_mn // cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + return + + curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) + if pid_n * BLOCK_N >= curr_N: + return + + cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = ( + token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + ) + + offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_expand_kernel_fp8( + pid_n, + lora_id, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + a_scale_ptr, + b_scale_ptr, + curr_N, + K, + cta_m_len, + ram, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + a_scale_m_stride, + a_scale_k_stride, + b_scale_l_stride, + b_scale_n_stride, + b_scale_k_stride, + output_d0_stride, + output_d1_stride, + group_n, + group_k, + BLOCK_M, + BLOCK_N, + BLOCK_K, + SAME_STRIDE, + SLICE_NUM, + EVEN_K, + CAST_TYPE, + ADD_INPUTS, + USE_GDC, + use_fp8_w8a8, + per_channel_quant, + ) + + +@torch.inference_mode() +def _lora_expand_fp8( + inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + lora_b_weights: list[torch.Tensor], # FP8 [num_lora, hidden_size, lora_rank] + output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices] + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, # shape [1] + num_active_loras: int, # number of active LoRAs (unused here, for API compat) + b_scale: list[torch.Tensor], # LoRA B weight scale per slice + a_scale: torch.Tensor | None = None, # Scale for shrink output (optional) + offset_start: int = 0, + add_inputs: bool = False, + group_k: int = 0, + group_n: int = 0, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, +) -> None: + """ + FP8-compatible LoRA expand operation. 
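+
+    Per slice, computes inputs[s] @ lora_b_weights[s]^T and writes (or, with
+    add_inputs=True, accumulates) the result into the corresponding column
+    block of output_tensor, dequantizing FP8 operands via a_scale/b_scale.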
+ + Args: + inputs: Input tensor from shrink operation [num_slices, num_tokens, lora_rank] + lora_b_weights: List of FP8 LoRA B weights per slice + output_tensor: Output tensor + a_scale: Optional scale for input (if input is quantized) + b_scale: Weight quantization scales per slice + token_lora_mapping: Token to LoRA ID mapping + token_indices_sorted_by_lora_ids: Sorted token indices + num_tokens_per_lora: Number of tokens per LoRA + lora_token_start_loc: Start location for each LoRA's tokens + lora_ids: LoRA IDs to process + no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates + if there are any requests that require LoRA. + offset_start (int, optional): Offset start for output_tensor. + Defaults to 0. + add_inputs (bool, optional): Whether to add the input tensor to the + output tensor. Defaults to False. + group_k (int, optional): Block size for K in block-wise quantization. + group_n (int, optional): Block size for N in block-wise quantization. + use_fp8_w8a8 (bool, optional): Whether to use FP8 W8A8 quantization. + per_channel_quant (bool, optional): Whether to use per-channel quantization. + """ + assert no_lora_flag_cpu.numel() == 1 + if no_lora_flag_cpu.item(): + # None of the inputs require LoRA. + return + + if use_fp8_w8a8: + assert inputs.dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ] + for weight in lora_b_weights: + assert weight.dtype in [ + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + else: + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + for weight in lora_b_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + assert inputs.size(0) == len(lora_b_weights) + assert output_tensor.is_contiguous() + + # metadata sanity check. + M = inputs.size(1) + assert token_lora_mapping.size(0) == M + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + hidden_sizes_tensor, + same_stride, + MAX_N, + ) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device) + + # Get scale pointers + if b_scale is not None: + b_scale_ptr_tensor = _get_expand_lora_scale_ptr(b_scale, inputs.device) + else: + b_scale_ptr_tensor = None + K = lora_b_weights[0].shape[-1] + ADD_INPUTS = add_inputs + MAX_LORAS = lora_ids.size(0) + + CAST_TYPE = False + NUM_SLICES = len(lora_b_weights) + + # Triton kernel configs. + kernel_config = get_lora_op_configs( + op_type="expand", + max_loras=MAX_LORAS, + batch=M, + hidden_size=MAX_N, + rank=K, + num_slices=NUM_SLICES, + add_inputs=add_inputs, + ) + BLOCK_M = kernel_config["block_m"] + BLOCK_N = kernel_config["block_n"] + BLOCK_K = kernel_config["block_k"] + NUM_WARPS = kernel_config["num_warps"] + NUM_CTAS = kernel_config.get("num_ctas", 1) + NUM_STAGES = kernel_config["num_stages"] + + EVEN_K = K % BLOCK_K == 0 + + grid = ( + triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), + NUM_SLICES, + num_active_loras, + ) + # We disable PDL temporarily because LoRA kernels are not launching back-to-back, + # making PDL invalid and affecting the kernel performance. 
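+    # Note: the grid's third axis is num_active_loras, so only active
+    # adapters launch programs.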
+ use_gdc = False # supports_pdl(inputs.device) + # Get scale strides + if a_scale is not None: + a_scale_m_stride = a_scale.stride(0) if a_scale.dim() > 1 else 0 + a_scale_k_stride = a_scale.stride(-1) if a_scale.dim() > 1 else 0 + else: + a_scale_m_stride = 0 + a_scale_k_stride = 0 + + if b_scale is not None and b_scale[0].dim() > 0: + b_scale_l_stride = b_scale[0].stride(0) if b_scale[0].dim() > 0 else 0 + b_scale_n_stride = ( + b_scale[0].stride(-2) + if b_scale[0].dim() > 2 + else (b_scale[0].stride(-1) if b_scale[0].dim() > 1 else 1) + ) + b_scale_k_stride = b_scale[0].stride(-1) if b_scale[0].dim() > 2 else 0 + else: + b_scale_l_stride = 1 + b_scale_n_stride = 0 + b_scale_k_stride = 0 + + _lora_expand_kernel_fp8[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + a_scale, + b_scale_ptr_tensor, + M, + MAX_N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_tensor, + inputs.stride(0), + inputs.stride(1), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + a_scale_m_stride, + a_scale_k_stride, + b_scale_l_stride, + b_scale_n_stride, + b_scale_k_stride, + output_tensor.stride(0), + output_tensor.stride(1), + hidden_sizes_tensor, + group_n, + group_k, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + NUM_SLICES, + same_stride, + use_gdc, + use_fp8_w8a8=use_fp8_w8a8, + per_channel_quant=per_channel_quant, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + launch_pdl=use_gdc, + ) + + return + + +def _lora_expand_fp8_fake( + inputs: torch.Tensor, + lora_b_weights: list[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, + num_active_loras: int, + b_scale: list[torch.Tensor], + a_scale: torch.Tensor | None = None, + offset_start: int = 0, + add_inputs: bool = False, + group_k: int = 0, + group_n: int = 0, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_expand_fp8", + op_func=_lora_expand_fp8, + mutates_args=["output_tensor"], + fake_impl=_lora_expand_fp8_fake, + ) + lora_expand_fp8 = torch.ops.vllm.lora_expand_fp8 + +except AttributeError: + lora_expand_fp8 = _lora_expand_fp8 diff --git a/vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py b/vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py new file mode 100644 index 000000000..d58368753 --- /dev/null +++ b/vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py @@ -0,0 +1,429 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch + +from vllm.lora.ops.triton_ops.fp8_kernel_utils import do_shrink_kernel_fp8 +from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs +from vllm.triton_utils import tl, triton +from vllm.utils.torch_utils import direct_register_custom_op + +_SHRINK_LORA_SCALE_PTR_DICT: dict[tuple[int, ...], tuple] = {} + + +def _get_shrink_lora_scale_ptr( + lora_scale_weights: list[torch.Tensor], device: torch.device +): + """ + `_SHRINK_LORA_SCALE_PTR_DICT` collects the required information during + `profile_run`. 
After profiling it remains constant and subsequent calls are
+    served from this lookup table.
+
+    Returns a tuple of (scale_ptr_tensor, l_stride, n_stride, k_stride).
+
+    Supports scale tensors of varying dimensionality:
+    - 1D (lora_num,): tensor-wise quantization
+    - 2D (lora_num, N): per-channel quantization
+    - 3D (lora_num, N, K): block-wise quantization
+    - 4D (lora_num, 1, N, K): block-wise with an extra dim (squeezed to 3D)
+
+    Refer to:
+    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
+    """
+    key = tuple(lora_weight.data_ptr() for lora_weight in lora_scale_weights)
+
+    if values := _SHRINK_LORA_SCALE_PTR_DICT.get(key):
+        return values
+
+    tensor_ptrs = []
+    scale_l_strides = []
+    scale_n_strides = []
+    scale_k_strides = []
+    for lora_scale_weight in lora_scale_weights:
+        if lora_scale_weight.ndim == 4:  # shape: (lora_num, 1, size, rank)
+            assert lora_scale_weight.size(1) == 1
+            lora_scale_weight = lora_scale_weight.squeeze(dim=1)
+        assert 1 <= lora_scale_weight.ndim <= 3
+        assert lora_scale_weight.is_contiguous()
+        tensor_ptrs.append(lora_scale_weight.data_ptr())
+        scale_l_strides.append(
+            lora_scale_weight.stride(0) if lora_scale_weight.ndim > 0 else 0
+        )
+        scale_n_strides.append(
+            lora_scale_weight.stride(-2)
+            if lora_scale_weight.ndim > 2
+            else (lora_scale_weight.stride(-1) if lora_scale_weight.ndim > 1 else 1)
+        )
+        scale_k_strides.append(
+            lora_scale_weight.stride(-1) if lora_scale_weight.ndim > 2 else 0
+        )
+    if len(lora_scale_weights) > 1:
+        scale_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
+    else:
+        scale_ptr_tensor = lora_scale_weights[0]
+
+    if (
+        len(set(scale_l_strides)) > 1
+        or len(set(scale_n_strides)) > 1
+        or len(set(scale_k_strides)) > 1
+    ):
+        raise ValueError("All LoRA scale weights must have the same stride.")
+
+    _SHRINK_LORA_SCALE_PTR_DICT[key] = (
+        scale_ptr_tensor,
+        scale_l_strides[0],
+        scale_n_strides[0],
+        scale_k_strides[0],
+    )
+    return _SHRINK_LORA_SCALE_PTR_DICT.get(key)
+
+
+@triton.jit
+def _lora_shrink_kernel_fp8(
+    input_ptr,
+    lora_ptr,
+    out_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
+    M,
+    N,
+    K,
+    token_indices_sorted_by_lora_ids,
+    num_tokens_per_lora,
+    lora_token_start_loc,
+    lora_ids,
+    scaling,
+    input_d0_stride,
+    input_d1_stride,
+    lora_d0_stride,
+    lora_d1_stride,
+    lora_d2_stride,
+    a_scale_m_stride,
+    a_scale_k_stride,
+    b_scale_l_stride,
+    b_scale_n_stride,
+    b_scale_k_stride,
+    output_d0_stride,
+    output_d1_stride,
+    output_d2_stride,
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    SPLIT_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    SLICE_NUM: tl.constexpr,
+    USE_GDC: tl.constexpr,  # should always be False in the shrink kernel
+    use_fp8_w8a8: tl.constexpr,
+    per_channel_quant: tl.constexpr,
+    launch_pdl: tl.constexpr,
+):
+    cta_n_num = tl.cdiv(N, BLOCK_N)
+    cta_m_num = tl.cdiv(M, BLOCK_M)
+
+    pid_sk_m_n = tl.program_id(axis=0)
+    pid_sk = pid_sk_m_n % SPLIT_K
+
+    pid_m_n = pid_sk_m_n // SPLIT_K
+    num_pid_in_group = GROUP_SIZE_M * cta_n_num
+    group_id = pid_m_n // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+
+    group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
+
+    # Column-major ordering within groups for better cache reuse
+    pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
+    pid_n = (pid_m_n % num_pid_in_group) // group_size_m
+
+    slice_id = tl.program_id(axis=1)
+    lora_idx = tl.program_id(axis=2)
+
+    lora_id = tl.load(lora_ids + lora_idx)
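+    # Slots in lora_ids that map to no active LoRA hold -1, so a single
+    # load is enough to retire the whole CTA below.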
+    if lora_id == -1:
+        # Early exit for the no-lora case.
+        return
+
+    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
+
+    cta_m_offset = pid_m * BLOCK_M
+    if cta_m_offset >= lora_m_size:
+        # Early exit CTA.
+        return
+
+    # Number of rows this CTA should process.
+    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
+
+    # Identify all rows that this CTA should process.
+    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
+    cta_lora_seq_indices = (
+        token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
+    )
+
+    # Load all relevant row indices.
+    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
+    ram = tl.load(cta_lora_seq_indices + offset_m)
+
+    do_shrink_kernel_fp8(
+        pid_n,
+        pid_sk,
+        slice_id,
+        lora_id,
+        input_ptr,
+        lora_ptr,
+        out_ptr,
+        a_scale_ptr,
+        b_scale_ptr,
+        N,
+        K,
+        cta_m_len,
+        ram,  # row indices of input_ptr this CTA operates on
+        # input strides
+        input_d0_stride,
+        input_d1_stride,
+        # lora strides
+        lora_d0_stride,
+        lora_d1_stride,
+        lora_d2_stride,
+        # scale strides
+        a_scale_m_stride,
+        a_scale_k_stride,
+        b_scale_l_stride,
+        b_scale_n_stride,
+        b_scale_k_stride,
+        # output strides
+        output_d0_stride,
+        output_d1_stride,
+        output_d2_stride,
+        scaling,
+        # block size for block-wise quantization
+        group_n,
+        group_k,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        EVEN_K,
+        SPLIT_K,
+        SLICE_NUM,
+        USE_GDC,
+        use_fp8_w8a8,
+        per_channel_quant,
+        launch_pdl,
+    )
+
+
+@torch.inference_mode()
+def _lora_shrink_fp8(
+    inputs: torch.Tensor,  # shape [num_tokens, hidden_size] - FP8 or FP16/BF16
+    lora_a_weights: list[
+        torch.Tensor
+    ],  # shape [num_loras, lora_rank, hidden_size] - FP8 or FP16/BF16
+    output_tensor: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
+    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
+    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
+    lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    num_active_loras: int,  # number of active LoRAs (unused here, for API compat)
+    scaling: float,
+    b_scale: list[torch.Tensor],  # LoRA weight scale per slice
+    a_scale: torch.Tensor | None = None,  # Activation scale - per-token or block-wise
+    group_k: int = 0,  # Block size for K in block-wise quantization (0 = tensor-wise)
+    group_n: int = 0,  # Block size for N in block-wise quantization
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+) -> None:
+    """
+    Args:
+        inputs: FP8 or FP16/BF16 input tensor [num_tokens, hidden_size]
+        lora_a_weights: List of FP8 or FP16/BF16 LoRA A weights per slice
+        output_tensor: Output tensor (FP16/BF16/FP32)
+        token_lora_mapping: Token to LoRA ID mapping
+        token_indices_sorted_by_lora_ids: Sorted token indices
+        num_tokens_per_lora: Number of tokens per LoRA
+        lora_token_start_loc: Start location for each LoRA's tokens
+        lora_ids: LoRA IDs to process
+        no_lora_flag_cpu: A CPU tensor of size 1 that indicates whether any
+            request requires LoRA.
+        num_active_loras: Number of active LoRAs (unused here, kept for API
+            compatibility).
+        scaling: LoRA scaling factor
+        b_scale: Weight quantization scales per slice
+        a_scale: Activation quantization scales
+        group_k: Block size for K dimension quantization
+        group_n: Block size for N dimension quantization
+        use_fp8_w8a8: Whether to use FP8 weights and activations
+        per_channel_quant: Whether to use per-channel quantization
+    """
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
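+        # The flag lives on the CPU, so this check avoids a device sync.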
+        return
+
+    assert inputs.size(1) == lora_a_weights[0].size(-1)
+    assert inputs.is_contiguous()
+    assert output_tensor.is_contiguous()
+
+    # metadata sanity check
+    M = inputs.size(0)
+    assert token_lora_mapping.size(0) == M
+    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
+    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
+    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
+
+    output_tensor.zero_()
+
+    # Get LoRA weight pointers
+    (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
+        _get_lora_a_ptr(lora_a_weights, inputs.device)
+    )
+
+    # Get scale pointers if using FP8
+    if use_fp8_w8a8:
+        assert a_scale is not None, "a_scale required for FP8 w8a8"
+        assert b_scale is not None, "b_scale required for FP8"
+
+        b_scale_ptr_tensor, b_scale_l_stride, b_scale_n_stride, b_scale_k_stride = (
+            _get_shrink_lora_scale_ptr(b_scale, inputs.device)
+        )
+        a_scale_ptr = (
+            a_scale if a_scale is not None else torch.tensor(1.0, device=inputs.device)
+        )
+    else:
+        b_scale_ptr_tensor = torch.tensor(0, device=inputs.device)
+        b_scale_l_stride = 0
+        b_scale_n_stride = 0
+        b_scale_k_stride = 0
+        a_scale_ptr = torch.tensor(0, device=inputs.device)
+
+    N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size, N=rank
+    NUM_SLICES = len(lora_a_weights)
+    MAX_LORAS = lora_ids.size(0)
+
+    # Triton kernel configs
+    kernel_config = get_lora_op_configs(
+        "shrink",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=K,
+        rank=N,
+        num_slices=NUM_SLICES,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    SPLIT_K = kernel_config["split_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_STAGES = kernel_config["num_stages"]
+    NUM_CTAS = kernel_config["num_ctas"]
+    GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
+    assert BLOCK_K is not None and SPLIT_K is not None
+    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
+
+    # Grid configuration with column-major ordering support
+    grid = (
+        SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
+        NUM_SLICES,
+        num_active_loras,
+    )
+
+    # Determine scale strides
+    if use_fp8_w8a8:
+        if a_scale is not None and a_scale.ndim == 2:
+            a_scale_m_stride = a_scale.stride(0)
+            a_scale_k_stride = a_scale.stride(1)
+        else:
+            a_scale_m_stride = 0
+            a_scale_k_stride = 0
+    else:
+        a_scale_m_stride = 0
+        a_scale_k_stride = 0
+
+    # We disable PDL temporarily because the LoRA kernels are not launched
+    # back-to-back, which invalidates PDL and hurts kernel performance.
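+    # Same rationale as in lora_expand_fp8_op.py: restore
+    # supports_pdl(inputs.device) once the launches are back-to-back.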
+ use_gdc = False # supports_pdl(inputs.device) + _lora_shrink_kernel_fp8[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + a_scale_ptr, + b_scale_ptr_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_strides_d0, + lora_strides_d1, + lora_strides_d2, + a_scale_m_stride, + a_scale_k_stride, + b_scale_l_stride, + b_scale_n_stride, + b_scale_k_stride, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor.stride(2), + group_n, + group_k, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + GROUP_SIZE_M, + NUM_SLICES, + use_gdc, + use_fp8_w8a8, + per_channel_quant, + use_gdc, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + ) + + return + + +def _lora_shrink_fp8_fake( + inputs: torch.Tensor, + lora_a_weights: list[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, + num_active_loras: int, + scaling: float, + b_scale: list[torch.Tensor], # LoRA weight scale per slice + a_scale: torch.Tensor | None = None, # Activation scale - per-token or block-wise + group_k: int = 0, # Block size for K in block-wise quantization (0 = tensor-wise) + group_n: int = 0, # Block size for N in block-wise quantization + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_shrink_fp8", + op_func=_lora_shrink_fp8, + mutates_args=["output_tensor"], + fake_impl=_lora_shrink_fp8_fake, + ) + lora_shrink_fp8 = torch.ops.vllm.lora_shrink_fp8 + +except AttributeError: + lora_shrink_fp8 = _lora_shrink_fp8 diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index a863b9726..ac32dd471 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -252,7 +252,7 @@ def get_lora_op_configs( default = { "block_m": 64, "block_n": 64 if num_slices > 1 else 128, - "block_k": 16, + "block_k": 32, "num_warps": 4, "num_ctas": 1, "num_stages": 2,
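
For reference, here is a minimal, untested driver sketch for the new shrink op.
All sizes, dtypes, and metadata values are hypothetical and hand-built; in vLLM
the sorted-token metadata comes from LoRAKernelMeta, and whether a given
combination of activation/weight scale shapes is supported ultimately depends
on do_shrink_kernel_fp8. The quantization recipe mirrors the per-token and
per-tensor helpers used in the tests.

import torch

from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import lora_shrink_fp8

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max
device = "cuda"

def quant_per_token(t: torch.Tensor):
    # Per-row FP8 quantization; scale shape (num_tokens, 1).
    s = t.abs().float().amax(dim=1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    return (t.float() / s).clamp(-FP8_MAX, FP8_MAX).to(FP8), s

def quant_per_tensor(t: torch.Tensor):
    # One scale for the whole tensor; scale shape (1,).
    s = t.abs().float().max().clamp(min=1e-12) / FP8_MAX
    return (t.float() / s).clamp(-FP8_MAX, FP8_MAX).to(FP8), s.reshape(1)

# 8 tokens, hidden_size 128, rank 16, 1 slice, 1 LoRA; all tokens use LoRA 0.
x_fp8, a_scale = quant_per_token(torch.randn(8, 128, device=device))
w_fp8, b_scale = quant_per_tensor(torch.randn(1, 16, 128, device=device))
out = torch.zeros(1, 8, 16, dtype=torch.bfloat16, device=device)

lora_shrink_fp8(
    x_fp8,
    [w_fp8],
    out,
    torch.zeros(8, dtype=torch.int32, device=device),  # token_lora_mapping
    torch.arange(8, dtype=torch.int32, device=device),  # sorted token indices
    torch.tensor([8], dtype=torch.int32, device=device),  # num_tokens_per_lora
    torch.tensor([0, 8], dtype=torch.int32, device=device),  # lora_token_start_loc
    torch.tensor([0], dtype=torch.int32, device=device),  # lora_ids
    torch.tensor([False]),  # no_lora_flag_cpu (stays on the CPU)
    1,  # num_active_loras
    1.0,  # scaling
    [b_scale],  # tensor-wise weight scale, one per slice
    a_scale,  # per-token activation scales, shape (8, 1)
    use_fp8_w8a8=True,
)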