Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -6,7 +6,7 @@ distributively on a multi-nodes cluster.
|
||||
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
@@ -36,13 +36,13 @@ class LLMPredictor:
|
||||
self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
||||
tensor_parallel_size=tensor_parallel_size)
|
||||
|
||||
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
|
||||
def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, list]:
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects that contain the prompt,
|
||||
# generated text, and other information.
|
||||
outputs = self.llm.generate(batch["text"], sampling_params)
|
||||
prompt: List[str] = []
|
||||
generated_text: List[str] = []
|
||||
prompt: list[str] = []
|
||||
generated_text: list[str] = []
|
||||
for output in outputs:
|
||||
prompt.append(output.prompt)
|
||||
generated_text.append(' '.join([o.text for o in output.outputs]))
|
||||
@@ -72,7 +72,7 @@ def scheduling_strategy_fn():
|
||||
pg, placement_group_capture_child_tasks=True))
|
||||
|
||||
|
||||
resources_kwarg: Dict[str, Any] = {}
|
||||
resources_kwarg: dict[str, Any] = {}
|
||||
if tensor_parallel_size == 1:
|
||||
# For tensor_parallel_size == 1, we simply set num_gpus=1.
|
||||
resources_kwarg["num_gpus"] = 1
|
||||
|
||||
Reference in New Issue
Block a user