# doo1 / main.py
# Jamiiwej2903's picture
# Update main.py
# 2c26da0 verified
import asyncio
import websockets
import json
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import logging
# Configure the root logger to emit INFO-and-above records, and create this
# module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ASGI application object; the startup hook and the "/" route below attach to it.
app = FastAPI()
# Hugging Face inference client bound to the Mixtral-8x7B instruct model.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
# WebSocket endpoint that websocket_client() connects to in order to fetch
# prompts and send back generated chunks.
# NOTE(review): this is a hard-coded ngrok tunnel hostname — presumably
# temporary; confirm it is still valid before deploying.
WEBSOCKET_URL = "wss://4d24-196-75-8-109.ngrok-free.app/ws"
# Parameters for one text-generation request, consumed by generate_stream().
# Only `prompt` is required; every other field has a default.
class Item(BaseModel):
    # Text of the new user message.
    prompt: str
    # Earlier conversation turns; format_prompt() unpacks each entry as a
    # (user_prompt, bot_response) pair.
    history: list = []
    # Text that generate_stream() places in front of the prompt.
    system_prompt: str = ""
    # Sampling temperature; generate_stream() raises it to at least 0.01.
    temperature: float = 0.0
    # Forwarded to the inference client as max_new_tokens.
    max_new_tokens: int = 1048
    # Forwarded to the inference client as top_p.
    top_p: float = 0.15
    # Forwarded to the inference client as repetition_penalty.
    repetition_penalty: float = 1.0
def format_prompt(message, history):
    """Build an ``[INST] ... [/INST]`` instruction prompt from a conversation.

    Args:
        message: The new user message to be answered.
        history: Iterable of ``(user_prompt, bot_response)`` pairs for the
            earlier turns, oldest first. ``None`` or an empty iterable means
            there are no earlier turns.

    Returns:
        A single string: ``<s>``, then each earlier turn rendered as
        ``[INST] user [/INST] response</s> ``, then the new message wrapped
        in ``[INST] ... [/INST]``.
    """
    # Collect the pieces and join once instead of growing a string with +=.
    pieces = ["<s>"]
    for user_prompt, bot_response in history or ():
        pieces.append(f"[INST] {user_prompt} [/INST] {bot_response}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
async def generate_stream(item: Item):
    """Stream generated text for *item* from the inference client, token by token.

    Args:
        item: The request parameters (prompt, history, system prompt and
            sampling settings).

    Yields:
        The text of each generated token, in the order the client returns them.
    """
    # Clamp the temperature to at least 0.01.
    # NOTE(review): presumably the backend rejects a zero temperature when
    # sampling is enabled — confirm.
    temperature = max(float(item.temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=item.max_new_tokens,
        top_p=float(item.top_p),
        repetition_penalty=item.repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Only prepend the system prompt when one was supplied; otherwise the
    # model would be sent a message starting with a stray ", ".
    if item.system_prompt:
        message = f"{item.system_prompt}, {item.prompt}"
    else:
        message = item.prompt
    formatted_prompt = format_prompt(message, item.history)

    # client.text_generation and the stream it returns are synchronous.
    # Running them directly in this coroutine would block the event loop
    # (and with it the HTTP routes and the WebSocket) for the whole
    # generation, so the call and every next() are pushed to the default
    # executor instead.
    loop = asyncio.get_running_loop()
    stream = await loop.run_in_executor(
        None,
        lambda: client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        ),
    )
    token_iter = iter(stream)
    # StopIteration cannot be propagated through a Future, so exhaustion is
    # signalled with a sentinel default passed to next().
    exhausted = object()
    while True:
        response = await loop.run_in_executor(None, next, token_iter, exhausted)
        if response is exhausted:
            break
        yield response.token.text
async def websocket_client():
    """Run forever as a worker: fetch prompts over the WebSocket and stream back results.

    Connects to WEBSOCKET_URL, repeatedly sends a ``getPrompt`` request and
    handles the reply:

    * ``prompt``  - generate text for it, sending one ``chunk`` message per
      token followed by a ``completed`` message;
    * ``error``   - log it and wait 5 seconds before asking again;
    * other types - log and ask again immediately.

    Any exception (including a closed connection) is logged, and the
    connection is re-established after a 5 second pause.
    """
    handled = 0
    while True:
        try:
            async with websockets.connect(WEBSOCKET_URL) as ws:
                logger.info("WebSocket connection established")
                while True:
                    # Request a prompt
                    request = json.dumps({"type": "getPrompt"})
                    await ws.send(request)
                    reply = json.loads(await ws.recv())
                    logger.info(f"Received: {reply}")
                    kind = reply.get('type')

                    if kind == 'error':
                        logger.info(f"Received error: {reply.get('message')}")
                        # Wait a bit before requesting a new prompt
                        await asyncio.sleep(5)
                        continue

                    if kind != 'prompt':
                        logger.info(f"Received unexpected message type: {kind}")
                        continue

                    # A prompt arrived: run generation and relay every token.
                    handled += 1
                    prompt_id = reply.get('id')
                    prompt_text = reply.get('prompt')
                    logger.info(f"Processing prompt: {prompt_text}")
                    logger.info(f"Call count: {handled}")

                    async for piece in generate_stream(Item(prompt=prompt_text)):
                        chunk_message = {
                            "type": "chunk",
                            "id": prompt_id,
                            "chunk": piece,
                        }
                        await ws.send(json.dumps(chunk_message))
                        logger.info(f"Sent chunk: {piece}")

                    done_message = {"type": "completed", "id": prompt_id}
                    await ws.send(json.dumps(done_message))
                    logger.info("Generation completed")
        except websockets.exceptions.ConnectionClosed:
            logger.error("WebSocket connection closed. Retrying in 5 seconds...")
        except Exception as e:
            logger.error(f"Error: {e}. Retrying in 5 seconds...")
        # The inner loop never finishes normally, so this point is only
        # reached after one of the handlers above: pause, then reconnect.
        await asyncio.sleep(5)
@app.on_event("startup")
async def startup_event():
    """Start the WebSocket worker loop as a background task when the app starts."""
    # The event loop keeps only weak references to tasks, so a task whose
    # result is discarded can be garbage-collected while still running.
    # Storing it on app.state keeps it alive for the application's lifetime.
    app.state.websocket_task = asyncio.create_task(websocket_client())
@app.get("/")
async def root():
    # Simple status endpoint: returns a fixed message so callers can check
    # that the service is up.
    status = {"message": "WebSocket client is running"}
    return status
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000. uvicorn is imported here so that importing this module
    # does not require it.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=8000)