[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (#17011)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.parameter import BasevLLMParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
@@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return F.linear(x, layer.weight, bias)
|
||||
return dispatch_unquantized_gemm()(x, layer.weight, bias)
|
||||
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
input_: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user