[Feature] Add async tensor parallelism for scaled mm (#20155)
Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
@@ -22,6 +22,8 @@ from ..utils import (compare_two_settings, create_new_process_for_each_test,
|
||||
multi_gpu_test)
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@@ -32,9 +34,10 @@ prompts = [
|
||||
|
||||
class TestMMRSModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16):
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.gate_proj = torch.nn.Parameter(torch.empty(
|
||||
(self.hidden_size * 2, hidden_size)),
|
||||
requires_grad=False)
|
||||
@@ -64,9 +67,10 @@ class TestMMRSModel(torch.nn.Module):
|
||||
|
||||
class TestAGMMModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16):
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.weight = torch.nn.Parameter(torch.empty(
|
||||
(hidden_size, hidden_size)),
|
||||
requires_grad=False)
|
||||
@@ -91,8 +95,125 @@ class TestAGMMModel(torch.nn.Module):
|
||||
return [torch.ops.symm_mem.fused_all_gather_matmul.default]
|
||||
|
||||
|
||||
class _BaseScaledMMModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
|
||||
.contiguous().transpose(0, 1)
|
||||
|
||||
# Initialize scale_b for _scaled_mm.
|
||||
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
|
||||
|
||||
|
||||
class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
|
||||
|
||||
"""
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||
scaled_mm = torch._scaled_mm(fp8_input,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the all gather + scaled_mm in the FX graph
|
||||
"""
|
||||
# Reshape input
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||
|
||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||
scaled_mm = torch._scaled_mm(all_gather,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype)
|
||||
return scaled_mm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_gather.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
|
||||
|
||||
|
||||
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the cutlass_scaled_mm + reduce scatter
|
||||
in the FX graph
|
||||
|
||||
"""
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=input.device)
|
||||
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
|
||||
self.scale_b, None)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the all gather + cutlass_scaled_mm
|
||||
in the FX graph
|
||||
"""
|
||||
# Reshape input
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||
|
||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||
|
||||
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=all_gather.device)
|
||||
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
|
||||
scale_a, self.scale_b, None)
|
||||
return mm_out
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_gather.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
|
||||
@pytest.mark.parametrize("test_model", [
|
||||
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@@ -101,6 +222,14 @@ class TestAGMMModel(torch.nn.Module):
|
||||
reason="Only test on CUDA")
|
||||
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel,
|
||||
TestAGCutlassScaledMMModel) and dtype == torch.float16:
|
||||
pytest.skip(
|
||||
"Only bf16 high precision output types are supported for " \
|
||||
"per-token (row-wise) scaling"
|
||||
)
|
||||
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
async_tp_pass = AsyncTPPass(vllm_config)
|
||||
backend = TestBackend(async_tp_pass)
|
||||
|
||||
model = test_model_cls(hidden_size)
|
||||
model = test_model_cls(hidden_size,
|
||||
dtype) # Pass dtype to model constructor
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype,
|
||||
@@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
|
||||
])
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("async_tp_enabled", [True])
|
||||
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
||||
|
||||
Reference in New Issue
Block a user