[Perf] Optimize token_embed for pooling models, 1.0% token throughput improvement (#37347)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-18 22:27:51 -04:00
committed by GitHub
parent 6accb21f2a
commit e37ff5b5c8
2 changed files with 24 additions and 6 deletions

View File

@@ -101,6 +101,7 @@ class PoolingMetadata:
num_scheduled_tokens_np: np.ndarray,
seq_lens_cpu: torch.Tensor,
device: torch.device,
query_start_loc_gpu: torch.Tensor | None = None,
):
n_seq = len(num_scheduled_tokens_np)
prompt_lens = self.prompt_lens
@@ -109,11 +110,25 @@ class PoolingMetadata:
index = list(range(n_seq))
num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np)
cumsum = torch.zeros(
n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
)
torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
cumsum = cumsum.to(device, non_blocking=True)
if query_start_loc_gpu is None:
cumsum = torch.zeros(
n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
)
torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
cumsum = cumsum.to(device, non_blocking=True)
else:
if query_start_loc_gpu.shape[0] != n_seq + 1:
raise ValueError(
"query_start_loc_gpu length does not match "
f"the number of sequences: {query_start_loc_gpu.shape[0]} "
f"!= {n_seq + 1}."
)
if query_start_loc_gpu.device != device:
raise ValueError(
"query_start_loc_gpu must be on the same device as the "
f"hidden states: {query_start_loc_gpu.device} != {device}."
)
cumsum = query_start_loc_gpu
self.pooling_cursor = PoolingCursor(
index=index,
first_token_indices_gpu=cumsum[:n_seq],

View File

@@ -2928,7 +2928,10 @@ class GPUModelRunner(
pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(
num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device
num_scheduled_tokens_np,
seq_lens_cpu,
device=hidden_states.device,
query_start_loc_gpu=self.query_start_loc.gpu[: num_reqs + 1],
)
model = cast(VllmModelForPooling, self.model)