New weight loader without np copy (#52)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional
|
||||
import json
|
||||
|
||||
import ray
|
||||
@@ -22,11 +22,12 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class FastAPIFrontend:
|
||||
class FastAPIServer:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
model_path: str,
|
||||
cache_dir: Optional[str],
|
||||
use_np_cache: bool,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
block_size: int,
|
||||
@@ -52,8 +53,9 @@ class FastAPIFrontend:
|
||||
remote_server_class = ray.remote(num_gpus=1)(Server)
|
||||
self.server = remote_server_class.remote(
|
||||
model=model,
|
||||
model_path=model_path,
|
||||
cache_dir=cache_dir,
|
||||
use_dummy_weights=False,
|
||||
use_np_cache=use_np_cache,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
block_size=block_size,
|
||||
@@ -148,7 +150,7 @@ class FastAPIFrontend:
|
||||
@app.post("/generate")
|
||||
async def generate_stream(request: Request):
|
||||
request_dict = await request.json()
|
||||
return StreamingResponse(frontend.generate(request_dict))
|
||||
return StreamingResponse(server.generate(request_dict))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -170,9 +172,10 @@ if __name__ == "__main__":
|
||||
pipeline_parallel_size=args.pipeline_parallel_size,
|
||||
tensor_parallel_size=args.tensor_parallel_size))
|
||||
|
||||
frontend = FastAPIFrontend(
|
||||
server = FastAPIServer(
|
||||
model=args.model,
|
||||
model_path=args.model_path,
|
||||
cache_dir=args.cache_dir,
|
||||
use_np_cache=args.use_np_cache,
|
||||
pipeline_parallel_size=args.pipeline_parallel_size,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
block_size=args.block_size,
|
||||
|
||||
Reference in New Issue
Block a user