[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (#17011)
Signed-off-by: charlifu <charlifu@amd.com>
@@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: List) -> torch.Tensor:
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
-            0] == 1 and qinput.shape[1] % 16 == 0:
+    from vllm.platforms.rocm import on_mi250_mi300
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
+    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                                current_platform.get_cu_count())
     else:
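The condition above only takes the wvSplitKQ path for a genuinely skinny
GEMM: the VLLM_ROCM_USE_SKINNY_GEMM flag must be set, the GPU must not be an
MI250/MI300, the quantized input must have a single row (M == 1), and K must
be a multiple of 16. A minimal sketch of that gate, assuming the flag and
device check are passed in as plain booleans (should_use_wvsplitkq is a
hypothetical helper name, not part of vLLM):

    import torch

    def should_use_wvsplitkq(qinput: torch.Tensor, use_skinny_gemm: bool,
                             is_mi250_mi300: bool) -> bool:
        # Mirrors the gate added in this hunk: skinny GEMM enabled, not on
        # MI250/MI300, single-row input, and K divisible by 16.
        m, k = qinput.shape
        return (use_skinny_gemm and not is_mi250_mi300
                and m == 1 and k % 16 == 0)

When the gate is false, the function falls through to the regular per-tensor
scaled-mm path unchanged.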
@@ -371,7 +372,7 @@ class Fp8LinearOp:
 
         return w8a8_scaled_mm_func(qinput=qinput,
                                    weight=weight,
-                                   out_dtype=input.dtype,
+                                   out_dtype=out_dtype,
                                    scale_a=x_scale,
                                    scale_b=weight_scale,
                                    bias=bias,
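The second hunk fixes the output-dtype plumbing in Fp8LinearOp: the call
previously hardcoded out_dtype=input.dtype, so a caller requesting a
different output dtype was silently ignored; it now forwards the out_dtype
argument. A reduced, self-contained sketch of why the forwarding matters
(this stand-in w8a8_scaled_mm_func is illustrative, not vLLM's actual kernel
dispatch):

    import torch

    def w8a8_scaled_mm_func(*, qinput, weight, out_dtype, scale_a, scale_b,
                            bias):
        # Dequantize, multiply, and cast to the dtype the caller asked for.
        out = (qinput.float() * scale_a) @ (weight.float() * scale_b)
        if bias is not None:
            out = out + bias.float()
        return out.to(out_dtype)

    qinput = torch.randint(-8, 8, (1, 64), dtype=torch.int8)   # toy quantized input
    weight = torch.randint(-8, 8, (64, 32), dtype=torch.int8)  # toy weight
    out = w8a8_scaled_mm_func(qinput=qinput, weight=weight,
                              out_dtype=torch.bfloat16,
                              scale_a=torch.tensor(0.1),
                              scale_b=torch.tensor(0.2), bias=None)
    assert out.dtype == torch.bfloat16  # honored only when out_dtype is forwarded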