[Bugfix][Mamba] - Fix Conv State Kernel FP32 Support (#24883)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3ed1ec4af2
commit
66072b36db
@@ -418,7 +418,9 @@ def test_full_cuda_graph(
|
||||
@pytest.mark.parametrize("model", FP32_STATE_MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_fp32_state(
|
||||
@pytest.mark.parametrize("cache_dtype_param",
|
||||
["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
|
||||
def test_fp32_cache_state(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
@@ -426,6 +428,7 @@ def test_fp32_state(
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
cache_dtype_param: str,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
@@ -443,13 +446,13 @@ def test_fp32_state(
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||
**{cache_dtype_param: "float32"}) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||
**{cache_dtype_param: "float32"}) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user