Spaces:
Sleeping
Sleeping
File size: 4,028 Bytes
5d82b1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import json
import os
import subprocess
import sys
import time
import threading
import aiohttp
import requests
from lagent.schema import AgentMessage
class HTTPAgentClient:
    """Synchronous client for an agent exposed over HTTP.

    Requests are sent to ``http://{host}:{port}`` on three routes:
    ``/health_check``, ``/chat_completion`` and ``/memory/{session_id}``.

    Args:
        host: Hostname or IP address of the server.
        port: TCP port of the server.
        timeout: Forwarded unchanged as the ``timeout`` argument of every
            ``requests`` call made by this client.
    """

    def __init__(self, host='127.0.0.1', port=8090, timeout=None):
        self.host = host
        self.port = port
        self.timeout = timeout

    @property
    def is_alive(self):
        """Report whether ``GET /health_check`` answers with HTTP 200.

        Returns:
            bool: ``True`` for a 200 response; ``False`` for any other
            status code or when the request itself fails.
        """
        try:
            resp = requests.get(
                f'http://{self.host}:{self.port}/health_check',
                timeout=self.timeout)
        except requests.RequestException:
            # Only request failures (refused connection, timeout, ...) mean
            # "not alive". A bare ``except`` would also swallow
            # KeyboardInterrupt/SystemExit and mask programming errors.
            return False
        return resp.status_code == 200

    def __call__(self, *message, session_id: int = 0, **kwargs):
        """POST messages to ``/chat_completion`` and return the reply.

        Args:
            *message: Messages to send. ``str`` items are sent as-is; every
                other item is serialized through its ``model_dump()``.
            session_id: Session identifier placed in the request body.
            **kwargs: Extra entries merged into the JSON request body.

        Returns:
            An ``AgentMessage`` validated from the response body when the
            status code is 200; otherwise the decoded JSON body as-is.
        """
        response = requests.post(
            f'http://{self.host}:{self.port}/chat_completion',
            json={
                'message': [
                    m if isinstance(m, str) else m.model_dump()
                    for m in message
                ],
                'session_id': session_id,
                **kwargs,
            },
            headers={'Content-Type': 'application/json'},
            timeout=self.timeout)
        resp = response.json()
        if response.status_code != 200:
            # Non-200 bodies are handed back undecoded into AgentMessage so
            # the caller can inspect the server's error payload.
            return resp
        return AgentMessage.model_validate(resp)

    def state_dict(self, session_id: int = 0):
        """Fetch ``GET /memory/{session_id}`` and return the decoded JSON.

        Args:
            session_id: Session whose memory endpoint is queried.

        Returns:
            The JSON-decoded response body.
        """
        resp = requests.get(
            f'http://{self.host}:{self.port}/memory/{session_id}',
            timeout=self.timeout)
        return resp.json()
class HTTPAgentServer(HTTPAgentClient):
    """Launch an agent HTTP server in a subprocess and act as its client.

    Constructing an instance starts the server immediately and blocks until
    the subprocess prints the ``Uvicorn running on`` banner.

    Args:
        gpu_id: Value exported to the subprocess as ``CUDA_VISIBLE_DEVICES``
            (converted with ``str``). ``None`` leaves the inherited value
            untouched.
        config: JSON-serializable object passed to the server script via
            ``--config``.
        host: Address passed to the server via ``--host``; also used by the
            inherited client methods.
        port: Port passed to the server via ``--port``; also used by the
            inherited client methods.

    Raises:
        RuntimeError: If the subprocess exits before reporting that the
            service is running.
    """

    def __init__(self, gpu_id, config, host='127.0.0.1', port=8090):
        super().__init__(host, port)
        self.gpu_id = gpu_id
        self.config = config
        self.start_server()

    def start_server(self):
        """Spawn the server subprocess and wait until it reports readiness.

        The subprocess output is echoed to this process's stdout by a daemon
        thread, which keeps draining the pipe after startup.

        Raises:
            RuntimeError: If the subprocess terminates before the
                ``Uvicorn running on`` banner is seen.
        """
        # set CUDA_VISIBLE_DEVICES in subprocess
        env = os.environ.copy()
        if self.gpu_id is not None:
            # Environment values must be strings; an int GPU index would make
            # Popen raise TypeError.
            env['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id)
        # NOTE(review): the script path is relative, so it is resolved
        # against the current working directory of this process.
        cmds = [
            sys.executable, 'lagent/distributed/http_serve/app.py', '--host',
            self.host, '--port',
            str(self.port), '--config',
            json.dumps(self.config)
        ]
        self.process = subprocess.Popen(
            cmds,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True)

        self.service_started = False

        def log_output(stream):
            # Echo every line and flip the flag once the banner shows up.
            for line in iter(stream.readline, ''):
                print(line, end='')
                if 'Uvicorn running on' in line:
                    self.service_started = True

        # stderr is merged into stdout above (``process.stderr`` is None),
        # so a single reader thread covers all of the subprocess output.
        reader = threading.Thread(
            target=log_output, args=(self.process.stdout,), daemon=True)
        reader.start()

        # Waiting for the service to start
        while not self.service_started:
            returncode = self.process.poll()
            if returncode is not None:
                # The server died during startup: give the reader a moment
                # to flush the remaining output, then fail instead of
                # spinning forever.
                reader.join(timeout=1)
                raise RuntimeError(
                    'HTTP agent server exited before startup completed '
                    f'(return code {returncode})')
            time.sleep(0.1)

    def shutdown(self):
        """Terminate the server subprocess and wait for it to exit."""
        self.process.terminate()
        self.process.wait()
class AsyncHTTPAgentMixin:
    """Mixin providing a coroutine ``__call__`` built on ``aiohttp``.

    The host class must supply ``host``, ``port`` and ``timeout``
    attributes; they are read on every call.
    """

    async def __call__(self, *message, session_id: int = 0, **kwargs):
        """POST messages to ``/chat_completion`` and return the reply.

        A new ``aiohttp.ClientSession`` is opened for each call and closed
        before returning.

        Args:
            *message: Messages to send. ``str`` items are sent as-is; every
                other item is serialized through its ``model_dump()``.
            session_id: Session identifier placed in the request body.
            **kwargs: Extra entries merged into the JSON request body.

        Returns:
            An ``AgentMessage`` validated from the response body when the
            status is 200; otherwise the decoded JSON body as-is.
        """
        client_timeout = aiohttp.ClientTimeout(self.timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as http:
            endpoint = f'http://{self.host}:{self.port}/chat_completion'
            serialized = [
                item if isinstance(item, str) else item.model_dump()
                for item in message
            ]
            body = {'message': serialized, 'session_id': session_id, **kwargs}
            request_headers = {'Content-Type': 'application/json'}
            async with http.post(
                    endpoint, json=body, headers=request_headers) as response:
                payload = await response.json()
                if response.status == 200:
                    return AgentMessage.model_validate(payload)
                # Any other status: hand the decoded body back untouched.
                return payload
class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient):
    """``HTTPAgentClient`` whose ``__call__`` is a coroutine.

    ``AsyncHTTPAgentMixin`` is listed first in the bases, so its
    ``__call__`` takes precedence over the synchronous one; every other
    member is inherited from ``HTTPAgentClient`` unchanged.
    """
class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer):
    """``HTTPAgentServer`` whose ``__call__`` is a coroutine.

    ``AsyncHTTPAgentMixin`` is listed first in the bases, so its
    ``__call__`` takes precedence over the synchronous one; construction,
    startup and shutdown are inherited from ``HTTPAgentServer`` unchanged.
    """
|