[ROCm] gemm_a16w16 upstreaming (#26969)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
Aleksandr Malyshev
2025-11-04 13:01:00 -08:00
committed by GitHub
parent 1fb4217a05
commit 2d977a7a9e
2 changed files with 43 additions and 9 deletions

View File

@@ -25,12 +25,14 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
@@ -153,6 +155,7 @@ class MLPBlock(torch.nn.Module):
self.layer_idx = layer_idx
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
@@ -177,7 +180,12 @@ class MLPBlock(torch.nn.Module):
if self.is_sequence_parallel:
x = sequence_parallel_chunk(x)
g = self.router(x)
if current_platform.is_rocm():
g = rocm_unquantized_gemm(
self, x[:, : self.hidden_size], self.router.weight, self.router.bias
)
else:
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)
if self.is_sequence_parallel: