[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer (#25438)
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: gaojingchun (A) <g00955623@china.huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple):
|
||||
class CPTestOptions(NamedTuple):
|
||||
multi_node_only: bool
|
||||
load_format: str | None = None
|
||||
attn_backend: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -58,6 +59,7 @@ class CPTestSettings:
|
||||
multi_node_only: bool = False,
|
||||
runner: RunnerOption = "auto",
|
||||
load_format: str | None = None,
|
||||
attn_backend: str | None = None,
|
||||
):
|
||||
parallel_setups = []
|
||||
for eager_mode_val in [False]:
|
||||
@@ -79,7 +81,9 @@ class CPTestSettings:
|
||||
distributed_backends=["mp"],
|
||||
runner=runner,
|
||||
test_options=CPTestOptions(
|
||||
- multi_node_only=multi_node_only, load_format=load_format
|
||||
multi_node_only=multi_node_only,
|
||||
load_format=load_format,
|
||||
attn_backend=attn_backend,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -117,7 +121,7 @@ def _compare_cp_with_tp(
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
|
||||
- multi_node_only, load_format = test_options
|
||||
multi_node_only, load_format, attn_backend = test_options
|
||||
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
@@ -177,6 +181,13 @@ def _compare_cp_with_tp(
|
||||
if hf_overrides:
|
||||
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
||||
|
||||
if not attn_backend:
|
||||
cp_env = tp_env = {}
|
||||
else:
|
||||
cp_env = tp_env = {
|
||||
"VLLM_ATTENTION_BACKEND": attn_backend,
|
||||
}
|
||||
|
||||
cp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
@@ -205,6 +216,8 @@ def _compare_cp_with_tp(
|
||||
model_id,
|
||||
cp_args,
|
||||
tp_args,
|
||||
cp_env,
|
||||
tp_env,
|
||||
method=method,
|
||||
max_wait_seconds=720,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user