[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (#17011)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-04-26 00:05:10 -05:00
committed by GitHub
parent 9e96f56efb
commit 54271bb766
4 changed files with 18 additions and 15 deletions

View File

@@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
+        return dispatch_unquantized_gemm()(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor: