[Kernel] GGUF MMVQ kernel for multiple input vectors (#18754)
Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
@@ -99,6 +99,10 @@ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
 def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
                         qweight_type: int) -> torch.Tensor:
+    if qweight_type in IMATRIX_QUANT_TYPES:
+        mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
+    else:
+        mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
     # HACK: when doing chunked prefill we don't generate output tokens
     # so input to logits generator is empty which causes invalid parameter
     if x.shape[0] == 0:
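The mmvq_safe value added above caps how many input vectors (x.shape[0]) may still be routed to the MMVQ kernel: a looser cap for importance-matrix quant types, a tighter one for tall weight matrices. A minimal sketch of that policy, using a hypothetical helper name (the thresholds are taken verbatim from the hunk):

def mmvq_batch_cap(qweight_rows: int, is_imatrix: bool) -> int:
    # Hypothetical helper mirroring the mmvq_safe computation above:
    # i-matrix quant types tolerate larger batches, and both families
    # shrink their cap once the weight has more than 5120 rows.
    if is_imatrix:
        return 8 if qweight_rows > 5120 else 16
    return 2 if qweight_rows > 5120 else 6

assert mmvq_batch_cap(4096, is_imatrix=True) == 16   # small i-quant weight
assert mmvq_batch_cap(8192, is_imatrix=False) == 2   # tall standard-quant weight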
@@ -110,7 +114,7 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
     # enable MMVQ in contiguous batching with batch_size=1
-    if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES:
+    if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
         y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
     # Use MMQ Kernel if it's available (standard + k-quants)
     elif qweight_type in MMQ_QUANT_TYPES:
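Taken together, the two hunks relax the dispatch condition from a batch of exactly 1 to any batch up to mmvq_safe, so short decode batches keep the faster matrix-vector path instead of dropping to MMQ. A self-contained sketch of the resulting selection order, with toy quant-type sets and an assumed dequantize fallback (only the <= mmvq_safe check and the MMVQ-before-MMQ ordering come from the diff):

# Toy stand-ins; the real sets live in vLLM's GGUF quantization module.
UNQUANTIZED_TYPES = {0}
MMVQ_QUANT_TYPES = {2, 3}
MMQ_QUANT_TYPES = {2, 3, 8}

def select_kernel(batch: int, mmvq_safe: int, qweight_type: int) -> str:
    # Mirrors the if/elif chain of _fused_mul_mat_gguf after this commit.
    if qweight_type in UNQUANTIZED_TYPES:
        return "dense"     # plain x @ qweight.T
    if batch <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
        return "mmvq"      # ops.ggml_mul_mat_vec_a8, now up to mmvq_safe vectors
    if qweight_type in MMQ_QUANT_TYPES:
        return "mmq"       # MMQ kernel (standard + k-quants)
    return "dequant"       # assumed fallback: dequantize, then dense matmul

print(select_kernel(batch=4, mmvq_safe=6, qweight_type=2))   # mmvq (previously only batch == 1)
print(select_kernel(batch=32, mmvq_safe=6, qweight_type=2))  # mmq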