[Perf] Optimize token_embed for pooling models, 1.0% token throughput improvement (#37347)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -101,6 +101,7 @@ class PoolingMetadata:
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
device: torch.device,
|
||||
+ query_start_loc_gpu: torch.Tensor | None = None,
|
||||
):
|
||||
n_seq = len(num_scheduled_tokens_np)
|
||||
prompt_lens = self.prompt_lens
|
||||
@@ -109,11 +110,25 @@ class PoolingMetadata:
|
||||
|
||||
index = list(range(n_seq))
|
||||
num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np)
|
||||
- cumsum = torch.zeros(
-     n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
- )
- torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
- cumsum = cumsum.to(device, non_blocking=True)
+ if query_start_loc_gpu is None:
+     cumsum = torch.zeros(
+         n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
+     )
+     torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
+     cumsum = cumsum.to(device, non_blocking=True)
+ else:
+     if query_start_loc_gpu.shape[0] != n_seq + 1:
+         raise ValueError(
+             "query_start_loc_gpu length does not match "
+             f"the number of sequences: {query_start_loc_gpu.shape[0]} "
+             f"!= {n_seq + 1}."
+         )
+     if query_start_loc_gpu.device != device:
+         raise ValueError(
+             "query_start_loc_gpu must be on the same device as the "
+             f"hidden states: {query_start_loc_gpu.device} != {device}."
+         )
+     cumsum = query_start_loc_gpu
|
||||
self.pooling_cursor = PoolingCursor(
|
||||
index=index,
|
||||
first_token_indices_gpu=cumsum[:n_seq],
|
||||
|
||||
@@ -2928,7 +2928,10 @@ class GPUModelRunner(
|
||||
|
||||
pooling_metadata = self.input_batch.get_pooling_metadata()
|
||||
pooling_metadata.build_pooling_cursor(
|
||||
- num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device
+ num_scheduled_tokens_np,
+ seq_lens_cpu,
+ device=hidden_states.device,
+ query_start_loc_gpu=self.query_start_loc.gpu[: num_reqs + 1],
|
||||
)
|
||||
|
||||
model = cast(VllmModelForPooling, self.model)
|
||||
|
||||
Reference in New Issue
Block a user