[Bugfix][Mamba] - Fix Conv State Kernel FP32 Support (#24883)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3ed1ec4af2
commit
66072b36db
@@ -415,6 +415,9 @@ def causal_conv1d_fn(
|
||||
activation = "silu"
|
||||
|
||||
args = None
|
||||
# Store original dtype to cast back at the end
|
||||
original_x_dtype = x.dtype
|
||||
x = x.to(conv_states.dtype)
|
||||
out = torch.empty_like(x)
|
||||
if metadata is not None:
|
||||
cu_seqlen = metadata.cu_seqlen
|
||||
@@ -613,7 +616,7 @@ def causal_conv1d_fn(
|
||||
BLOCK_N=256,
|
||||
num_stages=2,
|
||||
)
|
||||
return out
|
||||
return out.to(original_x_dtype)
|
||||
|
||||
|
||||
@triton.jit()
|
||||
@@ -973,6 +976,9 @@ def causal_conv1d_update(
|
||||
activation = "silu" if activation is True else None
|
||||
elif activation is not None:
|
||||
assert activation in ["silu", "swish"]
|
||||
|
||||
original_x_dtype = x.dtype
|
||||
x = x.to(conv_state.dtype)
|
||||
unsqueeze = query_start_loc is None and x.dim() == 2
|
||||
if unsqueeze:
|
||||
# make it (batch, dim, seqlen) with seqlen == 1
|
||||
@@ -1081,4 +1087,4 @@ def causal_conv1d_update(
|
||||
)
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
return out
|
||||
return out.to(original_x_dtype)
|
||||
|
||||
Reference in New Issue
Block a user