diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py
index d3a5cc901..410f7f0bd 100644
--- a/benchmark/benchmark_latency.py
+++ b/benchmark/benchmark_latency.py
@@ -6,7 +6,7 @@
 from tqdm import tqdm
 import numpy as np
 import torch
-from cacheflow.master.server import (
+from cacheflow.core.server import (
     add_server_arguments, process_server_arguments,
     init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
@@ -15,15 +15,14 @@ from cacheflow.sampling_params import SamplingParams
 
 def main(args: argparse.Namespace):
     server, frontend = init_local_server_and_frontend_with_arguments(args)
-    sampling_params_dict = {
-        'n': args.n,
-        'temperature': 0.0 if args.use_beam_search else 1.0,
-        'top_p': 1.0,
-        'use_beam_search': args.use_beam_search,
-        'stop_token_ids': set(),
-        'max_num_steps': args.output_len,
-    }
-    sampling_params = SamplingParams.from_dict(sampling_params_dict)
+    sampling_params = SamplingParams(
+        n=args.n,
+        temperature=0.0 if args.use_beam_search else 1.0,
+        top_p=1.0,
+        use_beam_search=args.use_beam_search,
+        stop_token_ids=set(),
+        max_tokens=args.output_len,
+    )
     print(sampling_params)
 
     input_token_ids = [0] * args.input_len
@@ -31,7 +30,8 @@ def main(args: argparse.Namespace):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
         for _ in range(args.batch_size):
-            frontend._add_query(input_token_ids, sampling_params)
+            dummy_prompt = ""
+            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
         server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
         while True:
diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py
index 12f6157ec..7d7d1dab7 100644
--- a/cacheflow/core/scheduler.py
+++ b/cacheflow/core/scheduler.py
@@ -316,7 +316,7 @@ class Scheduler:
                 continue
 
             # Check if the sequence has reached the maximum number of steps.
-            max_num_steps = self.sampling_params[group_id].max_num_steps
+            max_num_steps = self.sampling_params[group_id].max_tokens
             if self.num_steps[group_id] == max_num_steps:
                 self._free_seq(seq)
                 continue
diff --git a/cacheflow/frontend/fastapi_frontend.py b/cacheflow/frontend/fastapi_frontend.py
index c47120972..9dabd4dba 100644
--- a/cacheflow/frontend/fastapi_frontend.py
+++ b/cacheflow/frontend/fastapi_frontend.py
@@ -89,8 +89,8 @@ class FastAPIServer:
 
     async def generate(self, request_dict: Dict):
         # Preprocess the request.
-        prompt = request_dict["prompt"]
-        sampling_params = SamplingParams.from_dict(request_dict)
+        prompt = request_dict.pop("prompt")
+        sampling_params = SamplingParams(**request_dict)
         sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
         token_ids = self.tokenizer.encode(prompt)
         seqs: List[Sequence] = []
diff --git a/cacheflow/model_executor/layers/sampler.py b/cacheflow/model_executor/layers/sampler.py
index 838339c78..05703bce4 100644
--- a/cacheflow/model_executor/layers/sampler.py
+++ b/cacheflow/model_executor/layers/sampler.py
@@ -367,7 +367,7 @@ def _sample(
             next_token_ids = _sample_from_prompt(prob, sampling_params)
             # Get top-k log probabilities for the next tokens.
             next_logprobs = _get_topk_logprobs(
-                logprob, sampling_params.num_logprobs)
+                logprob, sampling_params.logprobs)
 
             # Build the output.
             for seq_id, next_token_id in zip(seq_ids, next_token_ids):
@@ -392,7 +392,7 @@
             next_logprobs: Dict[int, Dict[int, float]] = {}
             for i, seq_id in enumerate(seq_ids):
                 next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[i], sampling_params.num_logprobs)
+                    logprob[i], sampling_params.logprobs)
 
             # Build the output.
             for seq_id, parent_seq_id, next_token_id in zip(
diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py
index e8f670568..d1b36c410 100644
--- a/cacheflow/sampling_params.py
+++ b/cacheflow/sampling_params.py
@@ -5,16 +5,16 @@ class SamplingParams:
 
     def __init__(
         self,
-        n: int,
-        presence_penalty: float,
-        frequency_penalty: float,
-        temperature: float,
-        top_p: float,
-        top_k: int,
-        use_beam_search: bool,
-        stop_token_ids: Set[int],
-        max_num_steps: int,
-        num_logprobs: int,
+        n: int = 1,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        use_beam_search: bool = False,
+        stop_token_ids: Set[int] = set(),
+        max_tokens: int = 16,
+        logprobs: int = 0,
     ) -> None:
         if n < 1:
             raise ValueError(f"n must be at least 1, got {n}.")
@@ -32,12 +32,12 @@ class SamplingParams:
         if top_k < -1 or top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                              f"got {top_k}.")
-        if max_num_steps < 1:
+        if max_tokens < 1:
             raise ValueError(
-                f"max_num_steps must be at least 1, got {max_num_steps}.")
-        if num_logprobs < 0:
+                f"max_tokens must be at least 1, got {max_tokens}.")
+        if logprobs < 0:
             raise ValueError(
-                f"num_logprobs must be non-negative, got {num_logprobs}.")
+                f"logprobs must be non-negative, got {logprobs}.")
 
         if use_beam_search:
             if n == 1:
@@ -72,8 +72,8 @@ class SamplingParams:
         self.top_k = top_k
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
-        self.max_num_steps = max_num_steps
-        self.num_logprobs = num_logprobs
+        self.max_tokens = max_tokens
+        self.logprobs = logprobs
 
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
@@ -84,23 +84,5 @@ class SamplingParams:
                 f"top_k={self.top_k},"
                 f"use_beam_search={self.use_beam_search}, "
                 f"stop_token_ids={self.stop_token_ids}, "
-                f"max_num_steps={self.max_num_steps}, "
-                f"num_logprobs={self.num_logprobs}")
-
-    @classmethod
-    def from_dict(cls, d: Dict) -> "SamplingParams":
-        sampling_params = cls(
-            n=d.pop("n", 1),
-            presence_penalty=d.pop("presence_penalty", 0.0),
-            frequency_penalty=d.pop("frequency_penalty", 0.0),
-            temperature=d.pop("temperature", 1.0),
-            top_p=d.pop("top_p", 1.0),
-            top_k=d.pop("top_k", -1),
-            use_beam_search=d.pop("use_beam_search", False),
-            stop_token_ids=set(d.pop("stop_token_ids", set())),
-            max_num_steps=d.pop("max_num_steps", 16),
-            num_logprobs=d.pop("num_logprobs", 0),
-        )
-        if d:
-            raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
-        return sampling_params
+                f"max_tokens={self.max_tokens}, "
+                f"logprobs={self.logprobs}")
diff --git a/gradio_webserver.py b/gradio_webserver.py
index 290496da3..d819ecab0 100644
--- a/gradio_webserver.py
+++ b/gradio_webserver.py
@@ -10,7 +10,7 @@ def http_bot(prompt):
     headers = {"User-Agent": "Cacheflow Client"}
     pload = {
         "prompt": prompt,
-        "max_num_steps": 128,
+        "max_tokens": 128,
     }
     response = requests.post(args.model_url, headers=headers, json=pload,
                              stream=True)
diff --git a/simple_server.py b/simple_server.py
index c8cea42df..2aca052ba 100644
--- a/simple_server.py
+++ b/simple_server.py
@@ -18,7 +18,7 @@ def main(args: argparse.Namespace):
     while True:
         if test_inputs:
            text, sampling_params_dict = test_inputs.pop(0)
-            sampling_params = SamplingParams.from_dict(sampling_params_dict)
+            sampling_params = SamplingParams(**sampling_params_dict)
             sampling_params = frontend.add_eos_token(sampling_params)
             frontend.query(text, sampling_params)
         server.add_sequence_groups(frontend.get_inputs())
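For reference, a minimal sketch (not part of the diff) of how a call site migrates from the removed `SamplingParams.from_dict` to the new keyword constructor; the payload values below are illustrative:

```python
from cacheflow.sampling_params import SamplingParams

# Illustrative request payload; keys follow the renamed fields
# (max_num_steps -> max_tokens, num_logprobs -> logprobs).
request_dict = {"prompt": "Hello, world!", "max_tokens": 128}

# Before: prompt = request_dict["prompt"]
#         sampling_params = SamplingParams.from_dict(request_dict)
# After: pop non-sampling keys, then forward the rest as keyword arguments.
prompt = request_dict.pop("prompt")
sampling_params = SamplingParams(**request_dict)

# Unspecified fields fall back to the new constructor defaults
# (n=1, temperature=1.0, top_p=1.0, top_k=-1, max_tokens=16, logprobs=0),
# and an unrecognized key now raises TypeError rather than from_dict's ValueError.
print(sampling_params)
```

One design note: as written in the diff, `stop_token_ids` defaults to a single shared `set()`, so call sites that mutate it in place (as `FastAPIServer.generate` does via `add`) should pass a fresh set explicitly, as the benchmark script does.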