import contextlib
import functools
from typing import List, Optional, Tuple, Union

import torch

import vllm.envs as envs
from vllm._core_ext import ScalarType
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

if not current_platform.is_tpu():
    try:
        import vllm._C
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)

with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401


def hint_on_error(fn):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            raise e

    return wrapper
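
# NOTE: at the bottom of this module, ``hint_on_error`` is applied to every
# function defined here whose annotations mention ``torch.Tensor``, so a
# missing or stale ``torch.ops._C.*`` symbol surfaces with a rebuild hint
# instead of a bare AttributeError.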


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.silu_and_mul(out, x)
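
# The *_and_mul kernels in this section share a calling convention: ``x``
# holds the concatenated gate/up projections with last dimension ``2 * d``
# and the result is written into ``out`` with last dimension ``d``. A minimal
# sketch (shapes assumed for illustration, not executed here):
#
#   x = torch.randn(num_tokens, 2 * d, device="cuda", dtype=torch.float16)
#   out = torch.empty(num_tokens, d, device="cuda", dtype=torch.float16)
#   silu_and_mul(out, x)  # out = silu(x[..., :d]) * x[..., d:]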


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_quick(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step)
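
# Rough distinction between the two variants (the exact dispatch heuristic
# lives in the attention backend): ``paged_attention_v1`` processes each
# sequence in a single pass, while ``paged_attention_v2`` below splits long
# sequences into partitions and therefore also takes the ``exp_sum``,
# ``max_logits`` and ``tmp_out`` reduction buffers.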


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    torch.ops._C.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner."""
    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
                                     input_tokens, sampled_token_ids,
                                     input_positions, seq_lens, slot_mapping,
                                     block_tables)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_dequantize_triton)
        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_gemm_triton)
        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type, size_m,
                                            size_n, size_k)


# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[
        1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out
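
# Sketch of the expected calling convention (shapes assumed, consistent with
# the asserts above): ``a`` is the quantized activation of shape [m, k], ``b``
# the quantized weight of shape [k, n], and the kernel computes roughly
# ``out = (scale_a * a.float()) @ (scale_b * b.float()) + bias`` in
# ``out_dtype``. ``scale_a``/``scale_b`` may be per-tensor (a single element)
# or per-row/per-column vectors.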


def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.numel(
    ) == b.shape[1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                       azp, bias)
    return out


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: List[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: List[int]) -> torch.Tensor:
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)


# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty((num_experts, size_k // 16, size_n * 2),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for e in range(num_experts):
        output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e],
                                                    size_k, size_n, num_bits)
    return output


def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_fp32_reduce)


# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                                        num_bits, size_m, size_n, size_k)


# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
    return torch.ops._C.machete_supported_schedules(b_type)


def machete_gemm(
    a: torch.Tensor,
    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
    b_type: ScalarType,
    b_scales: Optional[torch.Tensor] = None,
    b_zeros: Optional[torch.Tensor] = None,
    b_group_size: Optional[int] = None,
    c: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    schedule: Optional[str] = None,
) -> torch.Tensor:
    return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
                                     b_group_size, c, alpha, beta, schedule)


def machete_prepack_B(b_q_weight: torch.Tensor,
                      b_type: ScalarType) -> torch.Tensor:
    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
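
# Typical machete flow (a sketch based on the parameter comment above): the
# quantized weight is reordered once with ``machete_prepack_B`` and the
# prepacked tensor is then passed as ``b_q`` to ``machete_gemm``, optionally
# with a schedule chosen from ``machete_supported_schedules(b_type)``.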


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: if you
    provide the scale, it will use static scaling; if you omit it, the scale
    will be determined dynamically. The function also allows optional padding
    of the output tensors for downstream kernels that will benefit from
    padding.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization.
        scale_ub: Optional upper bound for the scaling factor in the dynamic
            per-token case.
        num_token_padding: If specified, pad the first dimension of the output
            to at least this value.
        use_per_token_if_dynamic: Whether to do per-tensor or per-token
            quantization in the dynamic case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and the
        scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened.
    assert (input.ndim == 2)
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    # For ROCm, the output fp8 dtype is torch.float8_e4m3fnuz.
    out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
        else torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case.
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
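
# A minimal usage sketch (shapes, dtypes and the ``precomputed_scale`` tensor
# below are assumed for illustration; not executed at import time):
#
#   x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
#   # dynamic quantization with per-token scales, padded to >= 32 rows
#   x_fp8, x_scale = scaled_fp8_quant(x,
#                                     num_token_padding=32,
#                                     use_per_token_if_dynamic=True)
#   # static quantization with a precomputed per-tensor scale
#   x_fp8, _ = scaled_fp8_quant(x, scale=precomputed_scale)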


# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and
    scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        torch.ops._C.static_scaled_int8_quant(output, input, scale)
        return output, scale

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales
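
# A minimal usage sketch (shapes and the ``precomputed_scale`` tensor are
# assumed for illustration; not executed at import time):
#
#   x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
#   x_i8, x_scales = scaled_int8_quant(x)  # dynamic per-token scales
#   x_i8, s = scaled_int8_quant(x, scale=precomputed_scale)  # static scale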


# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
                                        workspace, size_m, size_n, size_k)


# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    return torch.ops._C.ggml_dequantize(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)


# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      seq_idx_: Optional[torch.Tensor],
                      initial_states_: Optional[torch.Tensor],
                      final_states_out_: Optional[torch.Tensor],
                      silu_activation: bool) -> torch.Tensor:
    return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,
                                          initial_states_, final_states_out_,
                                          silu_activation)


def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
                         silu_activation: bool) -> torch.Tensor:
    return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
                                             silu_activation)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
                       delta_bias_: Optional[torch.Tensor],
                       delta_softplus: bool, index_: Optional[torch.Tensor],
                       x: Optional[torch.Tensor]) -> List[torch.Tensor]:
    return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_,
                                           delta_bias_, delta_softplus, index_,
                                           x)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, k_scale, v_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype, k_scale,
                                                   v_scale)


def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    # ruff: noqa: E501
    return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
        device)


# custom ar
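# Rough lifecycle of the custom all-reduce ops below (a sketch; the exact
# orchestration lives in the custom all-reduce communicator): `init_custom_ar`
# returns an opaque handle `fa`, peer IPC buffers are shared via
# `register_buffer` / `register_graph_buffers`, reductions run through
# `all_reduce_reg` / `all_reduce_unreg`, and `dispose` releases the handle.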


def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
                                                 offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
                                                   full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)


def dispose(fa: int) -> None:
    torch.ops._C_custom_ar.dispose(fa)


def meta_size() -> int:
    return torch.ops._C_custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
names_and_values = globals()
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `from __future__ import annotations`, which
    # turns type hints into strings.
    if isinstance(v, fn_type) \
        and v.__code__.co_filename == __file__ \
        and any(arg is torch.Tensor or arg == "torch.Tensor"
                for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type