[ROCm] gemm_a16w16 upstreaming (#26969)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
parent 1fb4217a05
commit 2d977a7a9e
@@ -25,12 +25,14 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.utils import rocm_unquantized_gemm
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import sequence_parallel_chunk
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import cdiv

@@ -153,6 +155,7 @@ class MLPBlock(torch.nn.Module):

         self.layer_idx = layer_idx
         self.num_experts = config.num_local_experts
+        self.hidden_size = config.hidden_size
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
         self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)

@@ -177,7 +180,12 @@ class MLPBlock(torch.nn.Module):
         if self.is_sequence_parallel:
             x = sequence_parallel_chunk(x)

-        g = self.router(x)
+        if current_platform.is_rocm():
+            g = rocm_unquantized_gemm(
+                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
+            )
+        else:
+            g = self.router(x)
         x = self.experts(hidden_states=x, router_logits=g)

         if self.is_sequence_parallel:
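
For readers who want the routing change in isolation, the sketch below mirrors the pattern this last hunk introduces: on ROCm the unquantized (a16w16) router GEMM is dispatched through vLLM's ROCm helper, while other platforms keep the plain torch.nn.Linear path. The RouterOnly module and its sizes are illustrative, not part of the commit; rocm_unquantized_gemm and current_platform are the vLLM helpers imported in the first hunk.

# Minimal sketch (not part of the commit): RouterOnly and the example shapes are
# illustrative; rocm_unquantized_gemm and current_platform come from vLLM.
import torch

from vllm.model_executor.layers.utils import rocm_unquantized_gemm
from vllm.platforms import current_platform


class RouterOnly(torch.nn.Module):
    """Stand-in for the MLPBlock router path: hidden states -> router logits."""

    def __init__(self, hidden_size: int, num_experts: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.router = torch.nn.Linear(hidden_size, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if current_platform.is_rocm():
            # ROCm path: run the bf16/fp16 (a16w16) GEMM through vLLM's ROCm
            # helper, passing the router's weight and bias explicitly, as the
            # diff above does for MLPBlock.
            return rocm_unquantized_gemm(
                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
            )
        # Non-ROCm path: unchanged torch.nn.Linear projection.
        return self.router(x)


# Usage (illustrative shapes; requires a ROCm build of vLLM to take the fast path):
# router = RouterOnly(hidden_size=1024, num_experts=8)
# logits = router(torch.randn(4, 1024))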