[Bugfix][Model] Fix gpt-oss batch invariance (#35404)
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
This commit is contained in:
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.model_executor.layers.utils import (
|
||||
dispatch_unquantized_gemm,
|
||||
-    is_layer_moe_router_gate,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
@@ -257,11 +256,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
-        if (
-            vllm_is_batch_invariant()
-            and current_platform.is_cuda_alike()
-            and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
-        ):
+        if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
|
||||
return linear_batch_invariant(x, layer.weight, bias)
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
|
||||
@@ -23,7 +23,11 @@ from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
-from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
|
||||
@@ -165,7 +169,14 @@ class MLPBlock(torch.nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.experts_per_token = config.num_experts_per_tok
|
||||
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
-        self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
+        self.router = ReplicatedLinear(
+            config.hidden_size,
+            config.num_local_experts,
+            bias=True,
+            quant_config=None,
+            prefix=f"{prefix}.router",
+            return_bias=False,
+        )
|
||||
assert config.intermediate_size % self.world_size == 0
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
|
||||
Reference in New Issue
Block a user