[Mamba] Add stochastic rounding support (#35753)
Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn,
|
||||
selective_state_update,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
|
||||
@@ -429,6 +430,59 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("philox_rounds", [0, 4])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not current_platform.is_device_capability_family(100),
    reason="Stochastic rounding in triton is only supported"
    " on compute capability 10.0 CUDA devices.",
)
def test_selective_state_update_stochastic_rounding(dim, dstate, has_z, philox_rounds):
    """Check `selective_state_update` with stochastic rounding enabled.

    The kernel is run on an fp16 state with `enable_stochastic_rounding=True`
    and the parametrized `cache_philox_rounds`; its output and resulting state
    are compared against `selective_state_update_ref` evaluated on an fp32
    copy of the same initial state.
    """
    set_random_seed(0)

    cuda = "cuda"
    n_batch = 1
    tolerances = {"rtol": 5e-3, "atol": 1e-1}

    # NOTE: the creation order of the random tensors below determines their
    # values for the fixed seed, so it must not be shuffled.
    state = torch.randn(n_batch, dim, dstate, dtype=torch.float16, device=cuda)
    x = torch.randn(n_batch, dim, device=cuda, dtype=torch.bfloat16)
    out = torch.empty_like(x)
    dt = torch.randn(n_batch, dim, device=cuda, dtype=torch.bfloat16)
    dt_bias = torch.rand(dim, device=cuda) - 4.0
    A = -torch.rand(dim, dstate, device=cuda) - 1.0
    B = torch.randn(n_batch, dstate, device=cuda)
    C = torch.randn(n_batch, dstate, device=cuda)
    D = torch.randn(dim, device=cuda)
    if has_z:
        z = torch.randn_like(x)
    else:
        z = None

    # fp32 copy of the initial state, taken before the kernel is launched on
    # `state`, serves as the ground-truth state for the reference.
    state_ref = state.float()

    # Keyword arguments shared by the kernel under test and the reference.
    shared_kwargs = dict(D=D, z=z, dt_bias=dt_bias, dt_softplus=True)

    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        out=out,
        enable_stochastic_rounding=True,
        cache_philox_rounds=philox_rounds,
        **shared_kwargs,
    )
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, **shared_kwargs)

    # The state must still be stored in fp16 after the update.
    assert state.dtype == torch.float16
    expected_state = state_ref.to(torch.float16)
    assert torch.allclose(state, expected_state, **tolerances)
    assert torch.allclose(out, out_ref, **tolerances)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("dstate", [16, 64])
|
||||
|
||||
@@ -109,12 +109,20 @@ class CacheConfig:
|
||||
mamba_cache_mode: MambaCacheMode = "none"
|
||||
"""The cache strategy for Mamba layers.
|
||||
- "none": set when prefix caching is disabled.
|
||||
- "all": cache the mamba state of all tokens at position i * block_size. This is
|
||||
- "all": cache the mamba state of all tokens at position i * block_size. This is
|
||||
the default behavior (for models that support it) when prefix caching is
|
||||
enabled.
|
||||
- "align": only cache the mamba state of the last token of each scheduler step and
|
||||
when the token is at position i * block_size.
|
||||
"""
|
||||
enable_mamba_cache_stochastic_rounding: bool = False
|
||||
"""Enable stochastic rounding when writing SSM state to fp16 cache.
|
||||
Uses random bits to unbias the rounding error, which can improve
|
||||
numerical stability for long sequences."""
|
||||
mamba_cache_philox_rounds: int = 0
|
||||
"""Number of Philox PRNG rounds for stochastic rounding random number
|
||||
generation. 0 uses the Triton default. Higher values improve randomness
|
||||
quality at the cost of compute."""
|
||||
|
||||
# Will be set after profiling.
|
||||
num_gpu_blocks: int | None = field(default=None, init=False)
|
||||
@@ -231,3 +239,29 @@ class CacheConfig:
|
||||
"scaling factor."
|
||||
)
|
||||
return cache_dtype
|
||||
|
||||
def __post_init__(self):
    """Validate the Mamba-cache stochastic-rounding setting.

    Nothing is checked unless `enable_mamba_cache_stochastic_rounding` is set.

    Raises:
        ValueError: if the current platform is not CUDA, if the device is not
            in capability family 100, or if `mamba_ssm_cache_dtype` is not
            `"float16"`.
    """
    if not self.enable_mamba_cache_stochastic_rounding:
        return

    # Imported here rather than at module level; only needed when the
    # feature is requested.
    from vllm.platforms import current_platform

    # 1) Platform must be CUDA.
    on_cuda = current_platform.is_cuda()
    if not on_cuda:
        raise ValueError(
            "Stochastic rounding for Mamba cache is only supported "
            "on NVIDIA CUDA platforms. Please do not specify "
            "`--enable-mamba-cache-stochastic-rounding`."
        )

    # 2) Device must belong to capability family 100.
    has_cvt_rs = current_platform.is_device_capability_family(100)
    if not has_cvt_rs:
        raise ValueError(
            "Stochastic rounding for Mamba cache requires compute "
            "capability 10.0 (data center Blackwell). The `cvt.rs` PTX "
            "instruction is not supported on your GPU. Please do not specify "
            "`--enable-mamba-cache-stochastic-rounding`."
        )

    # 3) The SSM cache dtype must be explicitly float16.
    if self.mamba_ssm_cache_dtype == "float16":
        return
    raise ValueError(
        "Stochastic rounding for Mamba cache requires "
        "the SSM cache to be float16. Please set it explicitly, "
        "by specifying `--mamba-ssm-cache-dtype float16`, or disable "
        "stochastic rounding by not specifying "
        "`--enable-mamba-cache-stochastic-rounding`."
    )
|
||||
|
||||
@@ -604,6 +604,10 @@ class EngineArgs:
|
||||
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
|
||||
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
|
||||
mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
|
||||
enable_mamba_cache_stochastic_rounding: bool = (
|
||||
CacheConfig.enable_mamba_cache_stochastic_rounding
|
||||
)
|
||||
mamba_cache_philox_rounds: int = CacheConfig.mamba_cache_philox_rounds
|
||||
|
||||
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
|
||||
|
||||
@@ -1024,6 +1028,13 @@ class EngineArgs:
|
||||
cache_group.add_argument(
|
||||
"--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
|
||||
)
|
||||
cache_group.add_argument(
|
||||
"--enable-mamba-cache-stochastic-rounding",
|
||||
**cache_kwargs["enable_mamba_cache_stochastic_rounding"],
|
||||
)
|
||||
cache_group.add_argument(
|
||||
"--mamba-cache-philox-rounds", **cache_kwargs["mamba_cache_philox_rounds"]
|
||||
)
|
||||
cache_group.add_argument(
|
||||
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
|
||||
)
|
||||
@@ -1590,6 +1601,8 @@ class EngineArgs:
|
||||
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
|
||||
mamba_block_size=self.mamba_block_size,
|
||||
mamba_cache_mode=self.mamba_cache_mode,
|
||||
enable_mamba_cache_stochastic_rounding=self.enable_mamba_cache_stochastic_rounding,
|
||||
mamba_cache_philox_rounds=self.mamba_cache_philox_rounds,
|
||||
kv_offloading_size=self.kv_offloading_size,
|
||||
kv_offloading_backend=self.kv_offloading_backend,
|
||||
)
|
||||
|
||||
@@ -428,6 +428,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
state_batch_indices=state_indices_tensor_d_input,
|
||||
dst_state_batch_indices=state_indices_tensor_d_output,
|
||||
out=scan_outputs_d,
|
||||
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
|
||||
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
|
||||
)
|
||||
scan_outputs_d = scan_outputs_d.transpose(0, 1)
|
||||
|
||||
|
||||
@@ -888,6 +888,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
cu_seqlens=query_start_loc_d,
|
||||
is_blackwell=self.is_blackwell,
|
||||
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
|
||||
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
|
||||
)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
|
||||
@@ -28,6 +28,21 @@ else:
|
||||
return dt
|
||||
|
||||
|
||||
@triton.jit
def convert_rs_fp16x2(x: tl.tensor, rand: tl.tensor) -> tl.tensor:
    """Convert `x` to fp16 via the PTX `cvt.rs.f16x2.f32` instruction.

    Args:
        x: values to convert (the instruction's source type is f32).
        rand: tensor supplying the instruction's random-bits operand.

    Returns:
        A `tl.float16` tensor produced by the inline assembly.
    """
    # pack=2: each asm invocation handles two elements and emits one packed
    # f16x2 register. The constraint string declares one output and four
    # inputs; the asm reads $1/$2 (presumably the two `x` elements) and $3
    # (presumably the first `rand` element), leaving $4 unused.
    # NOTE(review): operand-to-element mapping follows Triton's packing
    # convention for inline_asm_elementwise; confirm against the Triton docs.
    return tl.inline_asm_elementwise(
        asm="""{
        cvt.rs.f16x2.f32 $0, $2, $1, $3;
        }""",
        constraints="=r,r,r,r,r",
        args=(x, rand),
        dtype=tl.float16,
        is_pure=True,
        pack=2,
    )
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@@ -48,6 +63,7 @@ else:
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr,
|
||||
rand_seed_ptr,
|
||||
x_ptr,
|
||||
dt_ptr,
|
||||
dt_bias_ptr,
|
||||
@@ -113,6 +129,8 @@ def _selective_scan_update_kernel(
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
USE_RS_ROUNDING: tl.constexpr,
|
||||
PHILOX_ROUNDS: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
@@ -267,7 +285,35 @@ def _selective_scan_update_kernel(
|
||||
z_ptr += stride_z_batch
|
||||
|
||||
if not IS_SPEC_DECODING:
|
||||
tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask)
|
||||
if USE_RS_ROUNDING:
|
||||
# Load random seed
|
||||
rand_seed = tl.load(rand_seed_ptr)
|
||||
# Generate random offsets for each element in state
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
rand_offsets = (
|
||||
state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
)
|
||||
else:
|
||||
rand_offsets = pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
rand_offsets += (
|
||||
offs_m[:, None] * stride_state_dim
|
||||
+ offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
# Generate random 32-bits for each element in state
|
||||
if PHILOX_ROUNDS > 0:
|
||||
rand = tl.randint(rand_seed, rand_offsets, PHILOX_ROUNDS)
|
||||
else:
|
||||
rand = tl.randint(rand_seed, rand_offsets)
|
||||
# Convert state to fp16 with RS rounding
|
||||
state = convert_rs_fp16x2(state, rand)
|
||||
tl.static_assert(state.dtype == tl.float16, "state must be fp16")
|
||||
tl.static_assert(
|
||||
dst_state_ptrs.dtype.element_ty == tl.float16,
|
||||
"dst_state_ptrs must be fp16",
|
||||
)
|
||||
else:
|
||||
state = state.to(dst_state_ptrs.dtype.element_ty)
|
||||
tl.store(dst_state_ptrs, state, mask=mask)
|
||||
|
||||
|
||||
def selective_state_update(
|
||||
@@ -288,6 +334,8 @@ def selective_state_update(
|
||||
num_accepted_tokens=None,
|
||||
cu_seqlens=None,
|
||||
is_blackwell=False,
|
||||
enable_stochastic_rounding=False,
|
||||
cache_philox_rounds=0,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
@@ -419,9 +467,16 @@ def selective_state_update(
|
||||
and dt.stride(-1) == 0
|
||||
and dt_bias.stride(-1) == 0
|
||||
)
|
||||
rand_seed = (
|
||||
torch.randint(0, 2**32, (1,), device=state.device)
|
||||
if enable_stochastic_rounding
|
||||
else None
|
||||
)
|
||||
|
||||
with torch.accelerator.device_index(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state,
|
||||
rand_seed,
|
||||
x,
|
||||
dt,
|
||||
dt_bias,
|
||||
@@ -476,6 +531,8 @@ def selective_state_update(
|
||||
tie_hdim,
|
||||
BLOCK_SIZE_M,
|
||||
num_warps=num_warps,
|
||||
USE_RS_ROUNDING=enable_stochastic_rounding,
|
||||
PHILOX_ROUNDS=cache_philox_rounds,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -445,6 +445,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
|
||||
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
|
||||
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
|
||||
)
|
||||
|
||||
# 4. Final linear projection
|
||||
|
||||
Reference in New Issue
Block a user