# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""FP8 accuracy tests for LoRA shrink and expand kernels.

Tests the FP8 kernels by:
1. Quantizing bf16 inputs/weights to FP8
2. Dequantizing them back to bf16
3. Running the bf16 reference (sgmv_shrink/sgmv_expand) with dequantized values
4. Comparing FP8 kernel output against this dequantized reference

This isolates kernel correctness from quantization precision loss,
allowing much tighter tolerances than comparing against the original bf16.
"""

import math
from threading import Lock

import pytest
import torch

import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_expand_fp8_op import (
    _EXPAND_LORA_SCALE_PTR_DICT,
)
from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import (
    _SHRINK_LORA_SCALE_PTR_DICT,
)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

DEVICE_TYPE = current_platform.device_type
DEVICES = [f"{DEVICE_TYPE}:{0}"]
SEED = [0]

# Guards the process-global kernel pointer-dict caches while tests clear them.
_dict_lock = Lock()


@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
    # Delegates to the project-level `reset_default_device` fixture; autouse
    # so every test starts from a clean default device.
    pass


# ============================================================================
# Reference implementations (bf16 baseline)
# ============================================================================


def sgmv_shrink_for_nslices(
    nslices,
    inputs_tensor,
    lora_weights_lst,
    out_tensor,
    b_seq_start_loc,
    seq_len_tensor,
    prompt_lora_mapping,
    batches,
    max_seq_length,
    num_tokens,
    scaling,
):
    """Wrapper around torch_ops.sgmv_shrink that handles any nslices."""
    # Each slice writes its own output plane: out_tensor[index].
    for index in range(nslices):
        torch_ops.sgmv_shrink(
            inputs_tensor,
            lora_weights_lst[index],
            out_tensor[index],
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            scaling,
        )


def sgmv_expand_for_nslices(
    nslices,
    hidden_size,
    inputs_tensor,
    lora_weights_lst,
    out_tensor,
    b_seq_start_loc,
    seq_len_tensor,
    prompt_lora_mapping,
    batches,
    max_seq_length,
    num_tokens,
    add_inputs,
):
    """Wrapper around torch_ops.sgmv_expand that handles any nslices."""
    if nslices == 1:
        torch_ops.sgmv_expand(
            inputs_tensor[0],
            lora_weights_lst[0],
            out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            add_inputs=add_inputs,
        )
    else:
        # Multi-slice: each slice owns a contiguous [slice_offset,
        # slice_offset + hidden_size) column span of the output.
        slice_offset = 0
        for index in range(nslices):
            torch_ops.sgmv_expand_slice(
                inputs_tensor[index],
                lora_weights_lst[index],
                out_tensor,
                b_seq_start_loc,
                seq_len_tensor,
                prompt_lora_mapping,
                batches,
                max_seq_length,
                num_tokens,
                slice_offset,
                hidden_size,
                add_inputs=add_inputs,
            )
            slice_offset += hidden_size


# ============================================================================
# FP8 Quantization Helpers
# ============================================================================

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max
FP8_MIN = torch.finfo(FP8_DTYPE).min


def quantize_to_fp8_per_tensor(
    tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor to FP8 with per-tensor scaling.

    Returns:
        (fp8_tensor, scale) where scale has shape (1,).
    """
    # clamp(min=1e-12) guards against a division by zero for all-zero input.
    amax = tensor.abs().float().max().clamp(min=1e-12)
    scale = (amax / FP8_MAX).to(torch.float32)
    fp8_tensor = (tensor.float() / scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
    return fp8_tensor, scale.reshape(1)


def quantize_to_fp8_per_channel(
    tensor: torch.Tensor,
    channel_dim: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor to FP8 with per-channel scaling.

    For shrink lora_a weights of shape (num_loras, rank, hidden_size):
        channel_dim=1 gives per-rank scaling -> scale shape (num_loras, rank)
    For expand lora_b weights of shape (num_loras, hidden_size, rank):
        channel_dim=1 gives per-hidden scaling
        -> scale shape (num_loras, hidden_size)
    """
    # Compute amax along all dims except the leading dims up to channel_dim+1
    reduce_dims = list(range(channel_dim + 1, tensor.ndim))
    if reduce_dims:
        amax = tensor.abs().float().amax(dim=reduce_dims).clamp(min=1e-12)
    else:
        amax = tensor.abs().float().clamp(min=1e-12)
    scale = (amax / FP8_MAX).to(torch.float32)
    # Broadcast view of the scale over the reduced dims, used only for the
    # quantization division below.
    bcast_scale = scale
    for _ in reduce_dims:
        bcast_scale = bcast_scale.unsqueeze(-1)
    fp8_tensor = (tensor.float() / bcast_scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
    # Return the scale with its documented shape tensor.shape[:channel_dim+1].
    # A blanket squeeze() here would also collapse leading size-1 dims (e.g.
    # num_loras == 1 would yield shape (rank,) instead of (1, rank)),
    # violating the contract above.
    if scale.ndim == 0:
        scale = scale.unsqueeze(0)
    return fp8_tensor, scale


def quantize_to_fp8_per_token(
    tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor to FP8 with per-token (per-row) scaling.

    Input shape: (num_tokens, hidden_size)
    Returns: (fp8_tensor, scale) where scale shape is (num_tokens, 1)
    """
    assert tensor.ndim == 2
    amax = tensor.abs().float().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = (amax / FP8_MAX).to(torch.float32)
    fp8_tensor = (tensor.float() / scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
    return fp8_tensor, scale


def quantize_to_fp8_blockwise(
    tensor: torch.Tensor,
    group_n: int,
    group_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D or 3D tensor to FP8 with block-wise scaling.

    For a 2D tensor (num_tokens, hidden_size):
        Blocks of size (1, group_k)
        -> scale shape (num_tokens, ceil(hidden_size/group_k))
    For a 3D tensor (num_loras, N, K):
        Blocks of size (group_n, group_k)
        -> scale shape (num_loras, ceil(N/group_n), ceil(K/group_k))
    """
    if tensor.ndim == 2:
        M, K = tensor.shape
        n_blocks_k = math.ceil(K / group_k)
        scale = torch.zeros(M, n_blocks_k, dtype=torch.float32, device=tensor.device)
        fp8_tensor = torch.zeros_like(tensor, dtype=FP8_DTYPE)
        for m in range(M):
            for bk in range(n_blocks_k):
                k_start = bk * group_k
                k_end = min(k_start + group_k, K)
                block = tensor[m, k_start:k_end].float()
                amax = block.abs().max().clamp(min=1e-12)
                s = (amax / FP8_MAX).to(torch.float32)
                scale[m, bk] = s
                fp8_tensor[m, k_start:k_end] = (
                    (block / s).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
                )
        return fp8_tensor, scale
    elif tensor.ndim == 3:
        L, N, K = tensor.shape
        n_blocks_n = math.ceil(N / group_n)
        n_blocks_k = math.ceil(K / group_k)
        scale = torch.zeros(
            L, n_blocks_n, n_blocks_k, dtype=torch.float32, device=tensor.device
        )
        fp8_tensor = torch.zeros_like(tensor, dtype=FP8_DTYPE)
        for li in range(L):
            for bn in range(n_blocks_n):
                for bk in range(n_blocks_k):
                    n_start = bn * group_n
                    n_end = min(n_start + group_n, N)
                    k_start = bk * group_k
                    k_end = min(k_start + group_k, K)
                    block = tensor[li, n_start:n_end, k_start:k_end].float()
                    amax = block.abs().max().clamp(min=1e-12)
                    s = (amax / FP8_MAX).to(torch.float32)
                    scale[li, bn, bk] = s
                    fp8_tensor[li, n_start:n_end, k_start:k_end] = (
                        (block / s).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
                    )
        return fp8_tensor, scale
    else:
        raise ValueError(f"Unsupported tensor ndim: {tensor.ndim}")


# ============================================================================
# FP8 Dequantization Helpers
# ============================================================================


def dequantize_fp8_per_tensor(
    fp8_tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dequantize FP8 tensor with per-tensor scale back to output_dtype."""
    return (fp8_tensor.float() * scale.float()).to(output_dtype)


def dequantize_fp8_per_channel(
    fp8_tensor: torch.Tensor,
    scale: torch.Tensor,
    channel_dim: int,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dequantize FP8 tensor with per-channel scale back to output_dtype.

    For 3D tensor (num_loras, N, K) with channel_dim=1:
        scale shape is (num_loras, N), broadcast over K.
    """
    expand_scale = scale.float()
    # Add trailing dims for broadcasting
    for _ in range(channel_dim + 1, fp8_tensor.ndim):
        expand_scale = expand_scale.unsqueeze(-1)
    return (fp8_tensor.float() * expand_scale).to(output_dtype)


def dequantize_fp8_per_token(
    fp8_tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dequantize FP8 2D tensor with per-token scale back to output_dtype.

    fp8_tensor: (num_tokens, hidden_size), scale: (num_tokens, 1)
    """
    return (fp8_tensor.float() * scale.float()).to(output_dtype)


def dequantize_fp8_blockwise(
    fp8_tensor: torch.Tensor,
    scale: torch.Tensor,
    group_n: int,
    group_k: int,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dequantize FP8 tensor with block-wise scale back to output_dtype."""
    if fp8_tensor.ndim == 2:
        M, K = fp8_tensor.shape
        out = torch.zeros(M, K, dtype=output_dtype, device=fp8_tensor.device)
        n_blocks_k = math.ceil(K / group_k)
        for m in range(M):
            for bk in range(n_blocks_k):
                k_start = bk * group_k
                k_end = min(k_start + group_k, K)
                out[m, k_start:k_end] = (
                    fp8_tensor[m, k_start:k_end].float() * scale[m, bk].float()
                ).to(output_dtype)
        return out
    elif fp8_tensor.ndim == 3:
        L, N, K = fp8_tensor.shape
        out = torch.zeros(L, N, K, dtype=output_dtype, device=fp8_tensor.device)
        n_blocks_n = math.ceil(N / group_n)
        n_blocks_k = math.ceil(K / group_k)
        for l_idx in range(L):
            for bn in range(n_blocks_n):
                for bk in range(n_blocks_k):
                    n_start = bn * group_n
                    n_end = min(n_start + group_n, N)
                    k_start = bk * group_k
                    k_end = min(k_start + group_k, K)
                    out[l_idx, n_start:n_end, k_start:k_end] = (
                        fp8_tensor[l_idx, n_start:n_end, k_start:k_end].float()
                        * scale[l_idx, bn, bk].float()
                    ).to(output_dtype)
        return out
    else:
        raise ValueError(f"Unsupported tensor ndim: {fp8_tensor.ndim}")


# ============================================================================
# FP8 Data Generation
# ============================================================================


def generate_fp8_shrink_data(
    batches: int,
    hidden_size: int,
    num_loras: int,
    rank: int,
    seq_length: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    quant_mode: str,  # "per_tensor", "per_channel", "blockwise"
    group_k: int = 128,
    group_n: int = 128,
):
    """Generate test data for FP8 shrink kernel.

    Shrink: output = input @ lora_a^T * scaling
    input: (num_tokens, hidden_size) -> quantized to FP8
    lora_a: (num_loras, rank, hidden_size) -> quantized to FP8

    Returns bf16 reference tensors, FP8 quantized tensors with scales,
    and dequantized bf16 tensors for accurate reference computation.
    """
    # All sequences have identical length: randint over [seq_length, seq_length+1).
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum().item()

    # Generate bf16 reference data
    inputs_bf16 = torch.randn(total_tokens, hidden_size, dtype=dtype, device=device)
    lora_a_weights_bf16 = []
    for _ in range(nslices):
        lora_a_weights_bf16.append(
            torch.randn(num_loras, rank, hidden_size, dtype=dtype, device=device)
        )

    # Quantize inputs to FP8 and dequantize back for reference
    if quant_mode == "blockwise":
        inputs_fp8, a_scale = quantize_to_fp8_blockwise(
            inputs_bf16, group_n=1, group_k=group_k
        )
        inputs_dequant = dequantize_fp8_blockwise(
            inputs_fp8,
            a_scale,
            group_n=1,
            group_k=group_k,
            output_dtype=dtype,
        )
    elif quant_mode == "per_tensor":
        # Per-tensor: kernel loads a single scalar from a_scale_ptr
        inputs_fp8, a_scale = quantize_to_fp8_per_tensor(inputs_bf16)
        inputs_dequant = dequantize_fp8_per_tensor(
            inputs_fp8,
            a_scale,
            output_dtype=dtype,
        )
    else:
        # per_channel: kernel loads per-token a_scale via ram indexing
        inputs_fp8, a_scale = quantize_to_fp8_per_token(inputs_bf16)
        inputs_dequant = dequantize_fp8_per_token(
            inputs_fp8,
            a_scale,
            output_dtype=dtype,
        )

    # Quantize lora_a weights to FP8 and dequantize back for reference
    b_scales = []
    lora_a_weights_fp8 = []
    lora_a_weights_dequant = []
    for w in lora_a_weights_bf16:
        if quant_mode == "per_tensor":
            w_fp8, w_scale = quantize_to_fp8_per_tensor(w)
            w_dequant = dequantize_fp8_per_tensor(w_fp8, w_scale, output_dtype=dtype)
            # Scale shape: (1,) -> need (num_loras,) for the kernel
            w_scale = w_scale.expand(num_loras).contiguous()
            lora_a_weights_fp8.append(w_fp8)
            b_scales.append(w_scale)
            lora_a_weights_dequant.append(w_dequant)
        elif quant_mode == "per_channel":
            # Per-channel along rank dim: scale shape (num_loras, rank)
            w_fp8, w_scale = quantize_to_fp8_per_channel(w, channel_dim=1)
            w_dequant = dequantize_fp8_per_channel(
                w_fp8,
                w_scale,
                channel_dim=1,
                output_dtype=dtype,
            )
            lora_a_weights_fp8.append(w_fp8)
            b_scales.append(w_scale)
            lora_a_weights_dequant.append(w_dequant)
        elif quant_mode == "blockwise":
            w_fp8, w_scale = quantize_to_fp8_blockwise(
                w, group_n=group_n, group_k=group_k
            )
            w_dequant = dequantize_fp8_blockwise(
                w_fp8,
                w_scale,
                group_n=group_n,
                group_k=group_k,
                output_dtype=dtype,
            )
            lora_a_weights_fp8.append(w_fp8)
            b_scales.append(w_scale)
            lora_a_weights_dequant.append(w_dequant)

    # Output tensor (float32 for shrink)
    out_tensor = torch.zeros(
        nslices, total_tokens, rank, dtype=torch.float32, device=device
    )
    ref_out_tensor = out_tensor.clone()

    # Token-to-lora mapping
    # NOTE(review): randint's upper bound is exclusive, so for num_loras > 1
    # the last lora index is never drawn; this mirrors the convention used by
    # the existing bf16 LoRA kernel tests and is kept for consistency.
    lora_indices_tensor = torch.randint(0, max(num_loras - 1, 1), (batches,)).to(device)
    token_lora_mapping = torch.zeros(total_tokens, dtype=torch.long, device=device)
    current_offset = 0
    for b_id in range(batches):
        lora_index = lora_indices_tensor[b_id]
        sl = seq_len_tensor[b_id].item()
        token_lora_mapping[current_offset : current_offset + sl] = lora_index
        current_offset += sl

    return {
        "inputs_bf16": inputs_bf16,
        "inputs_fp8": inputs_fp8,
        "inputs_dequant": inputs_dequant,
        "lora_a_bf16": lora_a_weights_bf16,
        "lora_a_fp8": lora_a_weights_fp8,
        "lora_a_dequant": lora_a_weights_dequant,
        "a_scale": a_scale,
        "b_scales": b_scales,
        "out_tensor": out_tensor,
        "ref_out_tensor": ref_out_tensor,
        "token_lora_mapping": token_lora_mapping,
        "seq_len_tensor": seq_len_tensor,
        "b_seq_start_loc": b_seq_start_loc,
        "lora_indices_tensor": lora_indices_tensor,
        "total_tokens": total_tokens,
    }


def generate_fp8_expand_data(
    batches: int,
    hidden_size: int,
    num_loras: int,
    rank: int,
    seq_length: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    quant_mode: str,  # "per_tensor", "per_channel", "blockwise"
    group_k: int = 128,
    group_n: int = 128,
):
    """Generate test data for FP8 expand kernel (w8a8).

    Expand: output += input @ lora_b^T
    input: (nslices, num_tokens, rank) -> quantized to FP8 (activations)
    lora_b: (num_loras, hidden_size, rank) -> quantized to FP8 (weights)

    In w8a8 mode, both activations and weights are FP8.

    Returns bf16 reference tensors, FP8 quantized tensors with scales,
    and dequantized bf16 tensors for accurate reference computation.
    """
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum().item()

    # Generate bf16 input (shrink output) and quantize to FP8
    inputs_bf16 = torch.randn(nslices, total_tokens, rank, dtype=dtype, device=device)

    # Quantize input to FP8 and dequantize back for reference
    inputs_2d_all = inputs_bf16.reshape(-1, rank)
    if quant_mode == "blockwise":
        # For blockwise, the kernel indexes a_scale by token id (0..total_tokens-1)
        # shared across slices. Compute shared scale across slices, then quantize.
        # First compute per-token-per-block scale across all slices
        n_blocks_k = math.ceil(rank / group_k)
        a_scale = torch.zeros(
            total_tokens, n_blocks_k, dtype=torch.float32, device=device
        )
        for m in range(total_tokens):
            for bk in range(n_blocks_k):
                k_start = bk * group_k
                k_end = min(k_start + group_k, rank)
                # Max across all slices for this token and block
                block_amax = torch.tensor(0.0, device=device)
                for s in range(nslices):
                    block = inputs_bf16[s, m, k_start:k_end].float()
                    block_amax = torch.max(
                        block_amax, block.abs().max().clamp(min=1e-12)
                    )
                a_scale[m, bk] = (block_amax / FP8_MAX).to(torch.float32)
        # Quantize all slices with the shared scale
        inputs_fp8_list = []
        inputs_dequant_list = []
        for s in range(nslices):
            slice_2d = inputs_bf16[s]  # (total_tokens, rank)
            fp8_slice = torch.zeros_like(slice_2d, dtype=FP8_DTYPE)
            dequant_slice = torch.zeros_like(slice_2d)
            for m in range(total_tokens):
                for bk in range(n_blocks_k):
                    k_start = bk * group_k
                    k_end = min(k_start + group_k, rank)
                    block = slice_2d[m, k_start:k_end].float()
                    s_val = a_scale[m, bk]
                    fp8_slice[m, k_start:k_end] = (
                        (block / s_val).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
                    )
                    dequant_slice[m, k_start:k_end] = (
                        fp8_slice[m, k_start:k_end].float() * s_val.float()
                    ).to(dtype)
            inputs_fp8_list.append(fp8_slice)
            inputs_dequant_list.append(dequant_slice)
        inputs_fp8 = torch.stack(inputs_fp8_list, dim=0)
        inputs_dequant = torch.stack(inputs_dequant_list, dim=0)
    elif quant_mode == "per_tensor":
        # Per-tensor: kernel loads a single scalar from a_scale_ptr
        inputs_fp8_2d, a_scale = quantize_to_fp8_per_tensor(inputs_2d_all)
        inputs_dequant_2d = dequantize_fp8_per_tensor(
            inputs_fp8_2d,
            a_scale,
            output_dtype=dtype,
        )
        inputs_fp8 = inputs_fp8_2d.reshape(nslices, total_tokens, rank)
        inputs_dequant = inputs_dequant_2d.reshape(nslices, total_tokens, rank)
    else:
        # per_channel: kernel loads per-token a_scale via ram indexing.
        # The kernel uses the same a_scale for all slices (indexed by token
        # id 0..total_tokens-1), so we compute a shared per-token scale
        # across all slices, then quantize each slice with that shared scale.
        per_slice_views = [inputs_bf16[s] for s in range(nslices)]
        # (nslices, total_tokens, rank) -> max across slices per token
        stacked = torch.stack(per_slice_views, dim=0)  # (nslices, tokens, rank)
        amax = stacked.abs().float().amax(dim=(0, 2), keepdim=False).clamp(min=1e-12)
        # amax shape: (total_tokens,)
        a_scale = (amax / FP8_MAX).to(torch.float32).unsqueeze(1)  # (tokens, 1)
        # Quantize all slices with the shared scale
        inputs_fp8_2d = (
            (inputs_2d_all.float() / a_scale.repeat(nslices, 1))
            .clamp(FP8_MIN, FP8_MAX)
            .to(FP8_DTYPE)
        )
        inputs_dequant_2d = (
            inputs_fp8_2d.float() * a_scale.repeat(nslices, 1).float()
        ).to(dtype)
        inputs_fp8 = inputs_fp8_2d.reshape(nslices, total_tokens, rank)
        inputs_dequant = inputs_dequant_2d.reshape(nslices, total_tokens, rank)

    # Generate bf16 LoRA B weights
    lora_b_weights_bf16 = []
    for _ in range(nslices):
        lora_b_weights_bf16.append(
            torch.randn(num_loras, hidden_size, rank, dtype=dtype, device=device)
        )

    # Quantize LoRA B weights to FP8 and dequantize back for reference
    b_scales = []
    lora_b_weights_fp8 = []
    lora_b_weights_dequant = []
    for w in lora_b_weights_bf16:
        if quant_mode == "per_tensor":
            w_fp8, w_scale = quantize_to_fp8_per_tensor(w)
            w_dequant = dequantize_fp8_per_tensor(w_fp8, w_scale, output_dtype=dtype)
            w_scale = w_scale.expand(num_loras).contiguous()
            lora_b_weights_fp8.append(w_fp8)
            b_scales.append(w_scale)
            lora_b_weights_dequant.append(w_dequant)
        elif quant_mode == "per_channel":
            # Per-channel along hidden_size dim: scale (num_loras, hidden_size)
            w_fp8, w_scale = quantize_to_fp8_per_channel(w, channel_dim=1)
            w_dequant = dequantize_fp8_per_channel(
                w_fp8,
                w_scale,
                channel_dim=1,
                output_dtype=dtype,
            )
            lora_b_weights_fp8.append(w_fp8)
            b_scales.append(w_scale)
            lora_b_weights_dequant.append(w_dequant)
        elif quant_mode == "blockwise":
            w_fp8, w_scale = quantize_to_fp8_blockwise(
                w, group_n=group_n, group_k=group_k
            )
            w_dequant = dequantize_fp8_blockwise(
                w_fp8,
                w_scale,
                group_n=group_n,
                group_k=group_k,
                output_dtype=dtype,
            )
            lora_b_weights_fp8.append(w_fp8)
            b_scales.append(w_scale)
            lora_b_weights_dequant.append(w_dequant)

    # Output tensor (initialized randomly for add_inputs)
    out_tensor = torch.randn(
        total_tokens, hidden_size * nslices, dtype=dtype, device=device
    )
    ref_out_tensor = out_tensor.clone()

    # Token-to-lora mapping
    # NOTE(review): randint's upper bound is exclusive, so for num_loras > 1
    # the last lora index is never drawn; this mirrors the convention used by
    # the existing bf16 LoRA kernel tests and is kept for consistency.
    lora_indices_tensor = torch.randint(0, max(num_loras - 1, 1), (batches,)).to(device)
    token_lora_mapping = torch.zeros(total_tokens, dtype=torch.long, device=device)
    current_offset = 0
    for b_id in range(batches):
        lora_index = lora_indices_tensor[b_id]
        sl = seq_len_tensor[b_id].item()
        token_lora_mapping[current_offset : current_offset + sl] = lora_index
        current_offset += sl

    return {
        "inputs_bf16": inputs_bf16,
        "inputs_fp8": inputs_fp8,
        "inputs_dequant": inputs_dequant,
        "a_scale": a_scale,
        "lora_b_bf16": lora_b_weights_bf16,
        "lora_b_fp8": lora_b_weights_fp8,
        "lora_b_dequant": lora_b_weights_dequant,
        "b_scales": b_scales,
        "out_tensor": out_tensor,
        "ref_out_tensor": ref_out_tensor,
        "token_lora_mapping": token_lora_mapping,
        "seq_len_tensor": seq_len_tensor,
        "b_seq_start_loc": b_seq_start_loc,
        "lora_indices_tensor": lora_indices_tensor,
        "total_tokens": total_tokens,
    }


# ============================================================================
# FP8 Shrink Kernel Check
# ============================================================================


def check_lora_shrink_fp8_kernel(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seq_length: int,
    scaling: float,
    quant_mode: str,
    group_k: int = 128,
    group_n: int = 128,
):
    """Test FP8 shrink kernel against dequantized bf16 reference.

    Instead of comparing FP8 kernel output against the original bf16
    reference (which conflates quantization error with kernel error), we:
    1. Quantize bf16 inputs/weights to FP8
    2. Dequantize them back to bf16
    3. Run the bf16 reference (sgmv_shrink) with the dequantized values
    4. Compare FP8 kernel output against this dequantized reference

    This isolates kernel correctness from quantization precision loss,
    allowing much tighter tolerances.
    """
    data = generate_fp8_shrink_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        device,
        quant_mode,
        group_k,
        group_n,
    )

    total_tokens = data["total_tokens"]

    # Setup LoRA kernel metadata
    lora_meta = LoRAKernelMeta.make(
        max_loras=num_loras, max_num_tokens=total_tokens, device=device
    )
    lora_meta.prepare_tensors(data["token_lora_mapping"])

    out_tensor = data["out_tensor"]

    # Determine quantization params for the kernel
    per_channel = quant_mode == "per_channel"
    gk = group_k if quant_mode == "blockwise" else 0
    gn = group_n if quant_mode == "blockwise" else 0

    with _dict_lock:
        _LORA_A_PTR_DICT.clear()
        _SHRINK_LORA_SCALE_PTR_DICT.clear()
        triton_ops.lora_shrink_fp8(
            data["inputs_fp8"],
            data["lora_a_fp8"],
            out_tensor,
            *lora_meta.meta_args(
                token_nums=total_tokens, specialize_active_lora=False
            ),
            scaling,
            data["b_scales"],
            a_scale=data["a_scale"],
            group_k=gk,
            group_n=gn,
            use_fp8_w8a8=True,
            per_channel_quant=per_channel,
        )

    # Compute reference using dequantized (round-tripped) tensors.
    # This means the reference sees the same quantization error as the kernel,
    # so any difference is purely kernel error.
    ref_out_tensor = data["ref_out_tensor"]
    max_seq_length = data["seq_len_tensor"].max().item()
    sgmv_shrink_for_nslices(
        nslices,
        data["inputs_dequant"],
        data["lora_a_dequant"],
        ref_out_tensor,
        data["b_seq_start_loc"],
        data["seq_len_tensor"],
        data["lora_indices_tensor"],
        batches,
        max_seq_length,
        total_tokens,
        scaling,
    )

    # With dequantized reference, we can use much tighter tolerances
    # since we're only measuring kernel error, not quantization error.
    # Blockwise accumulation order differs from the bf16 reference, so
    # allow a slightly larger margin for sporadic rounding outliers.
    rtol, atol = 0.1, 0.25
    torch.testing.assert_close(
        out_tensor.to(dtype), ref_out_tensor.to(dtype), rtol=rtol, atol=atol
    )


# ============================================================================
# FP8 Expand Kernel Check
# ============================================================================


def check_lora_expand_fp8_kernel(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seq_length: int,
    add_inputs: bool,
    quant_mode: str,
    group_k: int = 128,
    group_n: int = 128,
):
    """Test FP8 expand kernel (w8a8) against dequantized bf16 reference.

    Instead of comparing FP8 kernel output against the original bf16
    reference (which conflates quantization error with kernel error), we:
    1. Quantize bf16 inputs/weights to FP8
    2. Dequantize them back to bf16
    3. Run the bf16 reference (sgmv_expand) with the dequantized values
    4. Compare FP8 kernel output against this dequantized reference

    This isolates kernel correctness from quantization precision loss,
    allowing much tighter tolerances.
    """
    data = generate_fp8_expand_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        device,
        quant_mode,
        group_k,
        group_n,
    )

    total_tokens = data["total_tokens"]

    # Setup LoRA kernel metadata
    lora_meta = LoRAKernelMeta.make(
        max_loras=num_loras, max_num_tokens=total_tokens, device=device
    )
    lora_meta.prepare_tensors(data["token_lora_mapping"])

    out_tensor = data["out_tensor"]

    # Determine quantization params for the kernel
    per_channel = quant_mode == "per_channel"
    gk = group_k if quant_mode == "blockwise" else 0
    gn = group_n if quant_mode == "blockwise" else 0

    with _dict_lock:
        _LORA_B_PTR_DICT.clear()
        _EXPAND_LORA_SCALE_PTR_DICT.clear()
        triton_ops.lora_expand_fp8(
            data["inputs_fp8"],
            data["lora_b_fp8"],
            out_tensor,
            *lora_meta.meta_args(
                token_nums=total_tokens, specialize_active_lora=False
            ),
            data["b_scales"],
            a_scale=data["a_scale"],
            offset_start=0,
            add_inputs=add_inputs,
            group_k=gk,
            group_n=gn,
            use_fp8_w8a8=True,
            per_channel_quant=per_channel,
        )

    # Compute reference using dequantized (round-tripped) tensors.
    ref_out_tensor = data["ref_out_tensor"]
    max_seq_length = data["seq_len_tensor"].max().item()
    sgmv_expand_for_nslices(
        nslices,
        hidden_size,
        data["inputs_dequant"],
        data["lora_b_dequant"],
        ref_out_tensor,
        data["b_seq_start_loc"],
        data["seq_len_tensor"],
        data["lora_indices_tensor"],
        batches,
        max_seq_length,
        total_tokens,
        add_inputs=add_inputs,
    )

    # With dequantized reference, we can use much tighter tolerances
    # since we're only measuring kernel error, not quantization error.
    rtol, atol = 0.1, 0.15
    torch.testing.assert_close(out_tensor, ref_out_tensor, rtol=rtol, atol=atol)


# ============================================================================
# FP8 Test Parameters
# ============================================================================

fp8_test_params = {
    "hidden_sizes": [512, 1024, 2048],
    "batches": [1, 4, 16],
    "num_loras": [1, 4, 8],
    "max_ranks": [8, 16, 32, 64],
}


# ============================================================================
# FP8 Shrink Tests
# ============================================================================


@pytest.mark.parametrize("batches", fp8_test_params["batches"])
@pytest.mark.parametrize("num_loras", fp8_test_params["num_loras"])
@pytest.mark.parametrize("rank", fp8_test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", fp8_test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("quant_mode", ["per_tensor", "per_channel", "blockwise"])
def test_lora_shrink_fp8(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    quant_mode: str,
):
    """Test FP8 shrink kernel with per-tensor, per-channel, and block-wise
    quantization, comparing against the bf16 baseline."""
    torch.set_default_device(device)
    set_random_seed(seed)

    # For blockwise, group sizes must divide evenly or be handled by the kernel
    group_k = 128
    group_n = 128
    # Adjust group sizes if they're larger than the dimensions
    if quant_mode == "blockwise":
        group_k = min(group_k, hidden_size)
        group_n = min(group_n, rank)

    check_lora_shrink_fp8_kernel(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
        scaling=0.5,
        quant_mode=quant_mode,
        group_k=group_k,
        group_n=group_n,
    )


# ============================================================================
# FP8 Expand Tests
# ============================================================================


@pytest.mark.parametrize("batches", fp8_test_params["batches"])
@pytest.mark.parametrize("num_loras", fp8_test_params["num_loras"])
@pytest.mark.parametrize("rank", fp8_test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", fp8_test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("quant_mode", ["per_tensor", "per_channel", "blockwise"])
def test_lora_expand_fp8(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    quant_mode: str,
):
    """Test FP8 expand kernel with per-tensor, per-channel, and block-wise
    quantization, comparing against the bf16 baseline."""
    torch.set_default_device(device)
    set_random_seed(seed)

    group_k = 128
    group_n = 128
    # Adjust group sizes if they're larger than the dimensions
    if quant_mode == "blockwise":
        group_k = min(group_k, rank)
        group_n = min(group_n, hidden_size)

    check_lora_expand_fp8_kernel(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
        add_inputs=True,
        quant_mode=quant_mode,
        group_k=group_k,
        group_n=group_n,
    )