from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
import json
from dotenv import load_dotenv
from mistralai import Mistral
from prompts.instruction_prompts import instruction_prompt
from prompts.game_rules import game_rules
from prompts.hints import hints
from prompts.triggers import triggers
from helper_functions import load_chat_history, save_chat_history, update_chat_history
from utils import model, trump_character, client
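
# NOTE (assumption): load_dotenv is imported above but never called; calling it
# here makes variables from a local .env file (e.g. the Mistral API key)
# available via os.environ. If utils already loads them, this is a no-op.
load_dotenv()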

app = FastAPI()
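
# Assumed CORS configuration: CORSMiddleware is imported above but never
# registered. The allowed origin below is taken from the origin checks in the
# route handlers; adjust if the frontend is served from elsewhere.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://Mistral-AI-Game-Jam/team13.static.hf.space"],
    allow_methods=["*"],
    allow_headers=["*"],
)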

class Message(BaseModel):
    message: str

def generate_text(message: Message):
    # Load existing chat history
    chat_history = load_chat_history()

    # Add user message to history
    chat_history = update_chat_history(chat_history, user_message=message.message)

    # Format the prompt
    formatted_prompt = instruction_prompt.format(
        hints=hints,
        chat_history=chat_history,
        character=trump_character,
        rules=game_rules,
        triggers=triggers
    )

    # Get character's response
    chat_response = client.chat.complete(
        model=model,
        messages=[
            {
                "role": "system",
                "content": formatted_prompt
            },
            {
                "role": "user",
                "content": message.message
            }
        ]
    )
    clean_response = chat_response.choices[0].message.content

    # Add character response to history
    chat_history = update_chat_history(chat_history, character_response=clean_response)

    # Save updated chat history
    save_chat_history(chat_history)

    return {
        "character_response": clean_response,
        "chat_history": chat_history
    }

@app.post("/api/generate-text")
def inference(message: Message, request: Request):
    if request.headers.get("origin") != "https://Mistral-AI-Game-Jam/team13.static.hf.space":
        # Reject other origins with an empty 204 response instead of returning a
        # bare integer (which FastAPI would serialize as a 200 JSON body).
        return Response(status_code=204)
    return generate_text(message)


@app.get("/chat-history", tags=["History"])
def get_chat_history(request: Request):
    if request.headers.get("origin") != "https://Mistral-AI-Game-Jam/team13.static.hf.space":
        # Reject other origins with an empty 204 response (see inference above).
        return Response(status_code=204)
    chat_history = load_chat_history()
    return {"chat_history": chat_history}