New weight loader without np copy (#52)

This commit is contained in:
Zhuohan Li
2023-05-03 15:32:04 +08:00
committed by GitHub
parent 4858f3bb45
commit 27f1410d06
12 changed files with 284 additions and 352 deletions

View File

@@ -1,7 +1,7 @@
import argparse
import asyncio
import time
from typing import List, Dict
from typing import List, Dict, Optional
import json
import ray
@@ -22,11 +22,12 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()
class FastAPIFrontend:
class FastAPIServer:
def __init__(
self,
model: str,
model_path: str,
cache_dir: Optional[str],
use_np_cache: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
@@ -52,8 +53,9 @@ class FastAPIFrontend:
remote_server_class = ray.remote(num_gpus=1)(Server)
self.server = remote_server_class.remote(
model=model,
model_path=model_path,
cache_dir=cache_dir,
use_dummy_weights=False,
use_np_cache=use_np_cache,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
@@ -148,7 +150,7 @@ class FastAPIFrontend:
@app.post("/generate")
async def generate_stream(request: Request):
request_dict = await request.json()
return StreamingResponse(frontend.generate(request_dict))
return StreamingResponse(server.generate(request_dict))
if __name__ == "__main__":
@@ -170,9 +172,10 @@ if __name__ == "__main__":
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
frontend = FastAPIFrontend(
server = FastAPIServer(
model=args.model,
model_path=args.model_path,
cache_dir=args.cache_dir,
use_np_cache=args.use_np_cache,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,