[Feature] Add async tensor parallelism for scaled mm (#20155)

Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
cascade
2025-07-30 14:23:41 -07:00
committed by GitHub
parent f12d9256b3
commit 287f527f54
3 changed files with 381 additions and 8 deletions

View File

@@ -22,6 +22,8 @@ from ..utils import (compare_two_settings, create_new_process_for_each_test,
multi_gpu_test)
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -32,9 +34,10 @@ prompts = [
class TestMMRSModel(torch.nn.Module):
def __init__(self, hidden_size=16):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.gate_proj = torch.nn.Parameter(torch.empty(
(self.hidden_size * 2, hidden_size)),
requires_grad=False)
@@ -64,9 +67,10 @@ class TestMMRSModel(torch.nn.Module):
class TestAGMMModel(torch.nn.Module):
def __init__(self, hidden_size=16):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.nn.Parameter(torch.empty(
(hidden_size, hidden_size)),
requires_grad=False)
@@ -91,8 +95,125 @@ class TestAGMMModel(torch.nn.Module):
return [torch.ops.symm_mem.fused_all_gather_matmul.default]
class _BaseScaledMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
.contiguous().transpose(0, 1)
# Initialize scale_b for _scaled_mm.
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
class TestScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(fp8_input,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
return reduce_scatter
def ops_in_model_before(self):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
class TestAGScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + scaled_mm in the FX graph
"""
# Reshape input
fp8_input = input.to(FP8_DTYPE)
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(all_gather,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
return scaled_mm
def ops_in_model_before(self):
return [torch.ops.vllm.all_gather.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the cutlass_scaled_mm + reduce scatter
in the FX graph
"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=input.device)
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
self.scale_b, None)
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
return reduce_scatter
def ops_in_model_before(self):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + cutlass_scaled_mm
in the FX graph
"""
# Reshape input
fp8_input = input.to(FP8_DTYPE)
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=all_gather.device)
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
scale_a, self.scale_b, None)
return mm_out
def ops_in_model_before(self):
return [torch.ops.vllm.all_gather.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
@pytest.mark.parametrize("test_model", [
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@@ -101,6 +222,14 @@ class TestAGMMModel(torch.nn.Module):
reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel) and dtype == torch.float16:
pytest.skip(
"Only bf16 high precision output types are supported for " \
"per-token (row-wise) scaling"
)
num_processes = 2
def run_torch_spawn(fn, nprocs):
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)
model = test_model_cls(hidden_size)
model = test_model_cls(hidden_size,
dtype) # Pass dtype to model constructor
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype,
@@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
@pytest.mark.parametrize("model_id", [
"meta-llama/Llama-3.2-1B-Instruct",
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])