Spaces:
Running
Running
File size: 2,423 Bytes
7fc61e0 e888e88 7fc61e0 a5d32b6 7fc61e0 a5d32b6 7fc61e0 a5d32b6 7fc61e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import json
import os

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from mistralai import Mistral
from pydantic import BaseModel

from prompts.instruction_prompts import instruction_prompt
from prompts.game_rules import game_rules
from prompts.hints import hints
from prompts.triggers import triggers
from helper_functions import load_chat_history, save_chat_history, update_chat_history
from utils import model, trump_character, client
# FastAPI application; the endpoints below are registered on it through the
# ``@app.post`` / ``@app.get`` decorators.
# NOTE(review): CORSMiddleware is imported but never added to ``app``. The
# handlers expect requests from a *.static.hf.space front-end origin; if that
# front-end calls this API cross-origin, browsers will need CORS headers —
# confirm whether ``app.add_middleware(CORSMiddleware, ...)`` is missing.
app = FastAPI()
class Message(BaseModel):
    """Request body for the ``/api/generate-text`` endpoint."""

    # The player's text; read as ``message.message`` when building the prompt
    # and the user turn sent to the model.
    message: str
def generate_text(message: Message):
    """Run one conversational turn with the character.

    Loads the chat history via ``load_chat_history``, records the player's
    message in it, renders ``instruction_prompt`` with that history plus the
    hints, rules, triggers and character description, and asks the model for
    a reply. The reply is recorded in the history, which is then handed to
    ``save_chat_history``.

    Args:
        message: Request payload; ``message.message`` is the player's text.

    Returns:
        A dict with ``"character_response"`` (the reply text taken from the
        first completion choice) and ``"chat_history"`` (the history after
        both the player's and the character's turns were recorded).
    """
    user_text = message.message

    # Record the player's turn first so the rendered prompt already contains it.
    # NOTE(review): the player's text therefore reaches the model twice — inside
    # the system prompt's history and as the user turn. Confirm that is intended.
    history = update_chat_history(load_chat_history(), user_message=user_text)

    system_prompt = instruction_prompt.format(
        character=trump_character,
        rules=game_rules,
        hints=hints,
        triggers=triggers,
        chat_history=history,
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
    ]
    completion = client.chat.complete(model=model, messages=conversation)
    reply = completion.choices[0].message.content

    # Record the character's turn, then save the updated history.
    history = update_chat_history(history, character_response=reply)
    save_chat_history(history)

    return {"character_response": reply, "chat_history": history}
# Front-end origins allowed to use the endpoints below.
# A browser-sent ``Origin`` header is only ``scheme://host[:port]`` and never
# carries a path, so the original "https://Mistral-AI-Game-Jam/team13..." value
# can never match a real browser request. The lowercase, hyphenated subdomain
# is the form Hugging Face uses for static Spaces. The original literal is kept
# so any client that deliberately sends it is still accepted.
# NOTE(review): confirm the deployed front-end origin.
_ALLOWED_ORIGINS = frozenset(
    {
        "https://mistral-ai-game-jam-team13.static.hf.space",
        "https://Mistral-AI-Game-Jam/team13.static.hf.space",
    }
)


def _origin_allowed(request: Request) -> bool:
    """Return True when the request's ``Origin`` header is an allowed front-end origin.

    NOTE: ``Origin`` is supplied by the client, so this only filters ordinary
    browser traffic; it is not authentication.
    """
    return request.headers.get("origin") in _ALLOWED_ORIGINS


@app.post("/api/generate-text")
def inference(message: Message, request: Request):
    """Generate the character's reply to the player's message.

    Requests from a non-allowed ``Origin`` get an empty HTTP 204 response and
    nothing is generated. Otherwise the result of ``generate_text`` is returned.
    """
    if not _origin_allowed(request):
        # An explicit Response is needed for a real 204 status; returning the
        # bare int 204 would be serialized as a JSON body with status 200.
        return Response(status_code=204)
    return generate_text(message)


@app.get("/chat-history", tags=["History"])
def get_chat_history(request: Request):
    """Return the stored chat history as ``{"chat_history": ...}``.

    Requests from a non-allowed ``Origin`` get an empty HTTP 204 response.
    """
    if not _origin_allowed(request):
        return Response(status_code=204)
    chat_history = load_chat_history()
    return {"chat_history": chat_history}
|