Deepseek-v3 Batch Invariant on 8xH100 (#26609)
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
@@ -395,7 +395,6 @@ def mean_dim(
        Tensor with mean values along specified dimension
    """
    # Validate inputs
    assert input.is_cuda, "Input must be a CUDA tensor"
    assert -input.ndim <= dim < input.ndim, (
        f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
    )
@@ -470,6 +469,45 @@ def mm_batch_invariant(a, b):
    return matmul_persistent(a, b)


def matmul_batch_invariant(a, b, *, out=None):
    # torch.matmul can handle various dimensions.
    # For 2D x 2D, it's the same as mm.
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 3 and b.ndim == 3:
        # Handle the batched case like bmm
        return bmm_batch_invariant(a, b, out=out)
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
            f"got shapes {a.shape} and {b.shape}"
        )


def bmm_batch_invariant(a, b, *, out=None):
    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N).
    # Process each batch separately with our persistent kernel.
    if a.ndim == 3 and b.ndim == 3:
        results = []
        for i in range(a.shape[0]):
            results.append(matmul_persistent(a[i], b[i]))
        result = torch.stack(results, dim=0)

        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(
            f"bmm_batch_invariant expects 3D tensors, "
            f"got shapes {a.shape} and {b.shape}"
        )


def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)
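A minimal sketch of the property these overrides enforce (illustrative, assumes a CUDA device; not part of the diff). With stock cuBLAS kernels, row 0 of a batched mm and the same row computed at batch size 1 can differ in the last bits because the kernel's tiling and reduction order change with batch size; routing aten::mm through matmul_persistent makes the two bitwise equal:

    # Illustrative sketch, not part of this diff.
    import torch

    a = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

    full = torch.mm(a, b)        # row 0 computed inside a batch of 128
    single = torch.mm(a[:1], b)  # the same row computed as a batch of 1
    # torch.equal checks bitwise equality; True once batch-invariant
    # mode is enabled, not guaranteed otherwise.
    print(torch.equal(full[:1], single))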
@@ -479,11 +517,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
    return log_softmax(input, dim=dim)


def softmax_batch_invariant(input, dim, dtype=None):
    # Compute softmax in a deterministic way.
    # First subtract the max for numerical stability (standard practice).
    input_max = torch.amax(input, dim=dim, keepdim=True)
    input = input - input_max
    exp_x = torch.exp(input)
    sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
    return exp_x / sum_exp_x
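A quick CPU sanity sketch (hypothetical shapes, not part of the diff) showing that the explicit max-subtract/exp/sum decomposition above reproduces torch.softmax; the point of the decomposition is that each step lowers to ops whose reduction order does not depend on batch size:

    # Illustrative sketch, not part of this diff.
    import torch

    x = torch.randn(4, 7)
    m = torch.amax(x, dim=-1, keepdim=True)
    e = torch.exp(x - m)
    manual = e / e.sum(dim=-1, keepdim=True)
    print(torch.allclose(manual, torch.softmax(x, dim=-1), atol=1e-6))  # True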
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"

    result = input.to(torch.float32)

    if len(dim) == 0:
        dim = [i for i in range(len(input.shape))]

    # Sort dimensions to reduce from largest to smallest to handle shifting dims
    # during iterative reduction.
    sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
@@ -500,8 +551,134 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
    return result
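The descending sort matters because each reduction with keepdim=False removes a dimension, shifting the indices of the dims still to be reduced; reducing from the largest index down keeps the remaining indices valid. A small sketch of the same iterative scheme (hypothetical shapes, not part of the diff):

    # Illustrative sketch, not part of this diff.
    import torch

    x = torch.randn(2, 3, 4)
    result = x.to(torch.float32)
    for d in sorted([0, 2], reverse=True):  # reduce dim 2 first, then dim 0
        result = torch.sum(result, dim=d) / x.shape[d]
    print(torch.allclose(result, x.mean(dim=(0, 2))))  # True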
@triton.jit
def _rms_norm_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute RMS normalization along the last dimension of a 2D tensor.

    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight

    Each block handles one row of the input tensor.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Compute the sum of squares in float32 to avoid overflow
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        # Convert to float32 for accumulation to prevent overflow
        vals_f32 = vals.to(tl.float32)
        sq_vals = vals_f32 * vals_f32
        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))

    # Step 2: Compute the RMS (root mean square) in float32
    mean_sq = sum_sq / n_cols
    rms = tl.sqrt(mean_sq + eps)
    inv_rms = 1.0 / rms

    # Step 3: Normalize and apply the weight
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
        # Compute in float32, then convert back to the input dtype
        vals_f32 = vals.to(tl.float32)
        weight_f32 = weight.to(tl.float32)
        output_f32 = vals_f32 * inv_rms * weight_f32
        output = output_f32.to(vals.dtype)
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def rms_norm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Compute RMS normalization using a Triton kernel.

    RMS Norm normalizes the input by the root mean square and scales by weight:
        output = input / sqrt(mean(input^2) + eps) * weight

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        Tensor with RMS normalization applied along the last dimension
    """
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input.shape[-1]}) must match "
        f"weight dimension ({weight.shape[0]})"
    )

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()
    weight = weight.contiguous()

    n_rows, n_cols = input_2d.shape

    output = torch.empty_like(input_2d)
    BLOCK_SIZE = 1024
    grid = (n_rows,)
    _rms_norm_kernel[grid](
        input_2d,
        weight,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output.reshape(original_shape)
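A reference check one might run against this kernel (a sketch, assumes a CUDA device and hypothetical shapes; the tolerance reflects the bf16 output cast). The eager formula mirrors the kernel's float32 math:

    # Illustrative sketch, not part of this diff.
    import torch

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)

    x32 = x.to(torch.float32)
    ref = (x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + 1e-6)
           * w.to(torch.float32)).to(x.dtype)
    print(torch.allclose(rms_norm(x, w).float(), ref.float(), atol=1e-2))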
def rms_norm_batch_invariant(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Batch-invariant wrapper for RMS normalization.

    This function provides a deterministic, batch-invariant implementation
    of RMS normalization for use with the batch_invariant mode.

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        RMS normalized tensor
    """
    return rms_norm(input, weight, eps=eps)


def linear_batch_invariant(input, weight, bias=None):
    output = mm_batch_invariant(input, weight.t())
    if bias is not None:
        output = output + bias
    return output


_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
def is_batch_invariant_mode_enabled():
@@ -509,7 +686,7 @@ def is_batch_invariant_mode_enabled():


def enable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    if _batch_invariant_MODE:
        return
@@ -517,16 +694,28 @@ def enable_batch_invariant_mode():
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl(
        "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
    )
    _batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")

    # Also monkeypatch torch.bmm directly as a fallback
    _original_torch_bmm = torch.bmm
    torch.bmm = bmm_batch_invariant
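For readers unfamiliar with torch.library, a self-contained sketch of the override mechanism used above (the kernel body is a stand-in, not vLLM's Triton kernel; assumes a CUDA device):

    # Illustrative sketch, not part of this diff.
    import torch

    lib = torch.library.Library("aten", "IMPL")

    def my_mm(a, b):
        # Stand-in body; vLLM routes to matmul_persistent instead.
        # bmm is used so the override does not recursively re-enter aten::mm.
        return torch.bmm(a.unsqueeze(0), b.unsqueeze(0)).squeeze(0)

    lib.impl("aten::mm", my_mm, "CUDA")  # torch.mm on CUDA now routes here
    # lib._destroy() undoes the registration, as
    # disable_batch_invariant_mode() does below.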
def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    if _original_torch_bmm is not None:
        torch.bmm = _original_torch_bmm
        _original_torch_bmm = None
    _batch_invariant_MODE = False
    _batch_invariant_LIB = None
@@ -563,17 +752,55 @@ def vllm_kernel_override_batch_invariant():
    return is_overridden


def override_envs_for_invariance():
    curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
    supported_backends = [
        "FLASH_ATTN",  # best supported backend
        "FLEX_ATTENTION",
        "FLASHINFER",
        "FLASH_ATTN_MLA",
        "TRITON_MLA",
        # Not yet supported MLA backends
        # "FLASHMLA",
        # "FLASHINFER_MLA",
    ]
    if curr_attn_backend not in supported_backends:
        warning = (
            "Forcibly updating attention backend to"
            f" {supported_backends[0]} for batch_invariant. "
            f" Supported backends: {supported_backends}."
        )
        logger.warning_once(warning)
        os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
    if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
        warning = (
            "You are using a decode-invariant form of batch invariance. "
            "This will not be invariant between prefill and decode."
        )
        logger.warning_once(warning)
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # NCCL determinism settings
    os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
    os.environ["NCCL_COLLNET_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["NCCL_P2P_NET_DISABLE"] = "1"
    os.environ["NCCL_MIN_NCHANNELS"] = "1"
    os.environ["NCCL_MAX_NCHANNELS"] = "1"
    os.environ["NCCL_PROTO"] = "Simple"
    os.environ["NCCL_ALGO"] = "allreduce:tree"
    os.environ["NCCL_NTHREADS"] = "1"
    os.environ["NCCL_SOCKET_NTHREADS"] = "1"
def init_batch_invariance():
    # this will hit all the csrc overrides as well
    if vllm_kernel_override_batch_invariant():
-        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
-        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
-        if curr_attn_backend not in supported_backends:
-            warning = (
-                "Forcibly updating attention backend to"
-                f" {supported_backends[0]} for batch_invariant. "
-                f" Supported backends: {supported_backends}."
-            )
-            logger.warning_once(warning)
-            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+        override_envs_for_invariance()
        enable_batch_invariant_mode()

+        # Disable TF32 for batch invariance - it causes non-deterministic rounding
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
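End to end, usage looks roughly like the sketch below. The VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT flag name is an assumption inferred from vllm_kernel_override_batch_invariant(); the model and sizes are placeholders:

    # Illustrative sketch, not part of this diff.
    import os
    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"  # assumed flag name

    from vllm import LLM, SamplingParams

    llm = LLM(model="deepseek-ai/DeepSeek-V3", tensor_parallel_size=8)
    params = SamplingParams(temperature=0.0, max_tokens=64)
    # With the overrides active, a prompt's logprobs no longer depend on
    # which other requests happen to be batched alongside it.
    out = llm.generate(["The capital of France is"], params)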
@@ -15,6 +15,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    FusedMoEQuantConfig,
@@ -837,6 +840,10 @@ def get_moe_configs(
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # Avoid optimizing for the batch-invariant case. Use the default config.
    if vllm_kernel_override_batch_invariant():
        return None

    # First look up if an optimized configuration is available in the configs
    # directory
    block_shape = [block_n, block_k] if block_n and block_k else None
@@ -969,6 +976,15 @@ def get_default_config(
    dtype: str | None,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    if vllm_kernel_override_batch_invariant():
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
        return config

    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
        # BLOCK_SIZE_K must be divisible by block_shape[1]
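The fixed tiling above trades peak throughput for invariance: if every MoE GEMM resolves to the same block configuration, the Triton kernel's reduction order never changes with batch size. A schematic of the pattern (hypothetical helper names, not vLLM's API):

    # Illustrative sketch, not part of this diff.
    FIXED_CONFIG = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64,
                    "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}

    def pick_config(m: int, deterministic: bool) -> dict[str, int]:
        if deterministic:
            return FIXED_CONFIG  # same tiling for every batch size
        return tune_for(m)       # hypothetical autotuned lookup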
@@ -1118,7 +1134,10 @@ def fused_topk_bias(
    scores_for_choice = scores.view(
        -1, n_routed_experts
    ) + e_score_correction_bias.unsqueeze(0)
-    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
+    use_sorted = vllm_kernel_override_batch_invariant()
+    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
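Why sorted=True: with sorted=False, torch.topk may return the top-k in an unspecified order, so the gathered weights and downstream expert traversal can vary between otherwise identical runs. A small sketch (hypothetical scores, not part of the diff):

    # Illustrative sketch, not part of this diff.
    import torch

    scores = torch.tensor([[0.9, 0.1, 0.7, 0.5]])
    vals, idx = torch.topk(scores, k=2, dim=-1, sorted=True)
    print(idx)  # tensor([[0, 2]]) -- descending by value, every run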
@@ -1179,7 +1198,10 @@ def grouped_topk(
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
+    use_sorted = vllm_kernel_override_batch_invariant()
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
@@ -1192,11 +1214,13 @@
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
-        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
-        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+        topk_weights, topk_ids = torch.topk(
+            tmp_scores, k=topk, dim=-1, sorted=use_sorted
+        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -8,6 +8,10 @@ import torch.nn.functional as F

import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
    vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@@ -21,6 +25,8 @@ def rms_norm(
) -> torch.Tensor:
    from vllm import _custom_ops as ops

    if vllm_kernel_override_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)
    out = torch.empty_like(x)
    ops.rms_norm(
        out,
@@ -39,6 +45,10 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]:
    from vllm import _custom_ops as ops

    if vllm_kernel_override_batch_invariant():
        return rms_norm_batch_invariant(
            x + residual, weight, variance_epsilon
        ), x + residual
    ops.fused_add_rms_norm(
        x,
        residual,
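The batch-invariant branch decomposes the fused op into an add followed by the Triton RMS norm, returning both the normalized output and the updated residual. A sketch of the contract in eager PyTorch (hypothetical, not part of the diff):

    # Illustrative sketch, not part of this diff.
    import torch

    def fused_add_rms_norm_ref(x, residual, weight, eps):
        s = x + residual  # new residual stream, returned alongside the output
        s32 = s.to(torch.float32)
        normed = s32 * torch.rsqrt(s32.pow(2).mean(-1, keepdim=True) + eps)
        return (normed * weight.to(torch.float32)).to(x.dtype), s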
@@ -160,6 +160,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
            k_pe,
            output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
        )

        return self.o_proj(attn_out)[0]

    def forward_cuda(self, *args, **kwargs):
@@ -14,6 +14,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEActivationFormat,
@@ -353,6 +356,8 @@ class Fp8LinearMethod(LinearMethodBase):
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_kernel_override_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@@ -534,6 +539,66 @@ class Fp8LinearMethod(LinearMethodBase):
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # If batch-invariant mode is enabled, dequantize and use BF16 compute
        if vllm_kernel_override_batch_invariant():
            # Dequantize FP8 weights to BF16
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)

            # Handle different quantization granularities
            if self.block_quant:
                # Block-wise quantization:
                # - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
                # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
                assert self.weight_block_size is not None
                block_n, block_k = self.weight_block_size  # Note: order is [N, K]

                N, K = weight_fp8.shape

                # Scale is stored transposed: [num_blocks_k, num_blocks_n].
                # We need to transpose it to [num_blocks_n, num_blocks_k] first.
                weight_scale = weight_scale.t()

                # Expand the scale to match the weight dimensions;
                # scale_expanded should have shape [N, K].
                scale_expanded = weight_scale.repeat_interleave(
                    block_n, dim=0
                ).repeat_interleave(block_k, dim=1)
                # Trim to the exact weight size (in case of padding)
                scale_expanded = scale_expanded[:N, :K]
                weight_bf16 = weight_fp8 * scale_expanded
            else:
                # Per-tensor quantization: the weight IS transposed to [K, N];
                # the scale should be scalar, [1], or per-output-channel [N]
                if weight_scale.numel() == 1:
                    # Per-tensor: simple scalar multiplication
                    weight_bf16 = weight_fp8 * weight_scale
                else:
                    # Multiple scales (fused modules like QKV).
                    # Try to infer correct broadcasting:
                    # weight is [K, N], scale could be [num_logical_weights].
                    # Need to figure out how to broadcast - for now just try
                    # direct multiplication.
                    if (
                        weight_scale.dim() == 1
                        and weight_scale.shape[0] == weight_fp8.shape[0]
                    ):
                        # Per-row scaling
                        weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                    else:
                        # Fallback
                        weight_bf16 = weight_fp8 * weight_scale

            # For block quant the weight is [N, K]; for per-tensor it's [K, N].
            # F.linear expects the weight to be [N, K], so:
            if self.block_quant:
                # Already in the correct shape [N, K]
                output = torch.nn.functional.linear(x, weight_bf16, bias)
            else:
                # Need to transpose back: [K, N] -> [N, K]
                output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
            return output

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,