Add support for the Qwen3-Next model (a hybrid attention model). (#24526)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -156,9 +156,14 @@ class BlockTable:
|
||||
class MultiGroupBlockTable:
|
||||
"""The BlockTables for each KV cache group."""
|
||||
|
||||
def __init__(self, max_num_reqs: int, max_model_len: int,
|
||||
max_num_batched_tokens: int, pin_memory: bool,
|
||||
device: torch.device, block_sizes: list[int]) -> None:
|
||||
def __init__(self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0) -> None:
|
||||
# Note(hc): each dcp rank only stores
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
# so the block_size which is used to calculate max_num_blocks_per_req
|
||||
@@ -170,10 +175,11 @@ class MultiGroupBlockTable:
|
||||
dcp_world_size = 1
|
||||
|
||||
self.block_tables = [
|
||||
BlockTable(block_size, max_num_reqs,
|
||||
cdiv(max_model_len, block_size * dcp_world_size),
|
||||
max_num_batched_tokens, pin_memory, device)
|
||||
for block_size in block_sizes
|
||||
BlockTable(
|
||||
block_size, max_num_reqs,
|
||||
max(cdiv(max_model_len, block_size * dcp_world_size),
|
||||
1 + num_speculative_tokens), max_num_batched_tokens,
|
||||
pin_memory, device) for block_size in block_sizes
|
||||
]
|
||||
|
||||
def append_row(self, block_ids: tuple[list[int], ...],
|
||||
|
||||
Reference in New Issue
Block a user