[core][misc] remove logical block (#5882)

This commit is contained in:
youkaichao
2024-06-27 13:34:55 -07:00
committed by GitHub
parent 79c92c7c8a
commit 64e8d2a783
3 changed files with 16 additions and 120 deletions

View File

@@ -1,13 +1,13 @@
"""Sequence and its related classes."""
import copy
import enum
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
@@ -236,9 +236,6 @@ class Sequence:
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(self.prompt_token_ids)
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None
@@ -248,6 +245,10 @@ class Sequence:
# Input + output tokens
self.tokens: Optional[List[str]] = None
@property
def n_blocks(self) -> int:
    """Return ``get_len()`` divided by ``block_size``, rounded up.

    Rounding up means a partially filled trailing block still counts
    as one block.
    """
    total_len = self.get_len()
    blocks_needed = total_len / self.block_size
    return math.ceil(blocks_needed)
@property
def prompt(self) -> Optional[str]:
    """Return the ``"prompt"`` entry of ``self.inputs``, if present."""
    inputs = self.inputs
    # ``get`` yields None rather than raising when the key is missing.
    return inputs.get("prompt")
@@ -287,36 +288,12 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def _append_logical_block(self) -> None:
    """Append one new ``LogicalTokenBlock`` to ``self.logical_token_blocks``.

    The new block's ``block_number`` is the number of blocks already in
    the list, and its ``block_size`` is ``self.block_size``.
    """
    next_number = len(self.logical_token_blocks)
    new_block = LogicalTokenBlock(block_number=next_number,
                                  block_size=self.block_size)
    self.logical_token_blocks.append(new_block)
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
    """Distribute ``token_ids`` over ``self.logical_token_blocks`` in order.

    Tokens go into the last block until it reports full; a new block is
    then requested through ``self._append_logical_block()`` and filling
    continues there. An empty ``token_ids`` leaves the list untouched.
    """
    total = len(token_ids)
    offset = 0
    while offset < total:
        # Make sure there is at least one block to write into.
        if not self.logical_token_blocks:
            self._append_logical_block()

        tail = self.logical_token_blocks[-1]
        if tail.is_full():
            # The newest block has no room left: open another one.
            self._append_logical_block()
            tail = self.logical_token_blocks[-1]

        room = tail.get_num_empty_slots()
        # The slice may hold fewer than ``room`` tokens on the final pass.
        tail.append_tokens(token_ids[offset:offset + room])
        offset += room
def append_token_id(
    self,
    token_id: int,
    logprobs: Dict[int, Logprob],
) -> None:
    """Record one newly produced token on this sequence.

    Args:
        token_id: Id of the token to append; must be a key of
            ``logprobs``.
        logprobs: Mapping from token id to its ``Logprob`` entry. The
            whole mapping is appended to ``self.output_logprobs``.
    """
    assert token_id in logprobs
    new_tokens = [token_id]
    self._append_tokens_to_blocks(new_tokens)
    self.output_logprobs.append(logprobs)
    # Only the chosen token's log-probability is forwarded to the data.
    chosen = logprobs[token_id]
    self.data.append_token_id(token_id, chosen.logprob)
@@ -388,7 +365,7 @@ class Sequence:
def __repr__(self) -> str:
    """Return a short debug string: sequence id, status name, block count.

    The block count comes from ``self.n_blocks``, so this does not rely
    on the ``logical_token_blocks`` list.
    """
    # Exactly one ``num_blocks`` fragment, and it closes the
    # ``Sequence(...)`` text instead of ending on a dangling ", ".
    return (f"Sequence(seq_id={self.seq_id}, "
            f"status={self.status.name}, "
            f"num_blocks={self.n_blocks})")
@dataclass