[Core] Implement disagg prefill by StatelessProcessGroup (#10502)
This PR provides initial support for single-node disaggregated prefill in 1P1D scenario. Signed-off-by: KuntaiDu <kuntai@uchicago.edu> Co-authored-by: ApostaC <yihua98@uchicago.edu> Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
This commit is contained in:
61
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
Normal file
61
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from quart import Quart, make_response, request
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
app = Quart(__name__)
|
||||
|
||||
|
||||
async def forward_request(url, data):
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
async with session.post(url=url, json=data,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
# if response.headers.get('Transfer-Encoding') == 'chunked':
|
||||
if True:
|
||||
async for chunk_bytes in response.content.iter_chunked(
|
||||
1024):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
content = await response.read()
|
||||
yield content
|
||||
|
||||
|
||||
@app.route('/v1/completions', methods=['POST'])
|
||||
async def handle_request():
|
||||
try:
|
||||
original_request_data = await request.get_json()
|
||||
|
||||
prefill_request = original_request_data.copy()
|
||||
# change max_tokens = 1 to let it only do prefill
|
||||
prefill_request['max_tokens'] = 1
|
||||
|
||||
# finish prefill
|
||||
async for _ in forward_request('http://localhost:8100/v1/completions',
|
||||
prefill_request):
|
||||
continue
|
||||
|
||||
# return decode
|
||||
generator = forward_request('http://localhost:8200/v1/completions',
|
||||
original_request_data)
|
||||
response = await make_response(generator)
|
||||
response.timeout = None
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
exc_info = sys.exc_info()
|
||||
print("Error occurred in disagg prefill proxy server")
|
||||
print(e)
|
||||
print("".join(traceback.format_exception(*exc_info)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(port=8000)
|
||||
Reference in New Issue
Block a user