Spaces:
Sleeping
Sleeping
File size: 6,415 Bytes
3a19185 7fc61e0 20a632a 7fc61e0 e888e88 7fc61e0 3a19185 7fc61e0 a5d32b6 3a19185 b7a57b8 3a19185 a5d32b6 20a632a 3a19185 20a632a 7fc61e0 a5d32b6 3a19185 78484fd 3a19185 b7a57b8 3a19185 78484fd 3a19185 78484fd 3a19185 78484fd 3a19185 6aab2ec 3a19185 78484fd 3a19185 7fc61e0 3a19185 7fc61e0 3a19185 7fc61e0 3a19185 7fc61e0 3a19185 78484fd 3a19185 241e696 876b4c4 241e696 7fc61e0 3a19185 fab8c1b 3a19185 6aab2ec b7a57b8 c80c485 6aab2ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
import json
import os
import traceback

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from mistralai import Mistral
from pydantic import BaseModel

from original_setup.instruction_prompts import instruction_prompt
from original_setup.game_rules import game_rules
from original_setup.hints import hints
from original_setup.triggers import triggers
from helper_functions import *
from utils import model, trump_character, client
from graph_utils import *
# --- Module-level game state shared by the request handlers ---------------
# NOTE(review): this is process-global mutable state, so it presumably
# assumes a single worker process serving one player at a time — confirm
# before scaling out.
init_game = True  # True until the first request starts a new game
game_over_rich = False  # currently not read or written by the handlers in this module
game_over_bankrupt = False  # currently not read or written by the handlers in this module

# Initial world map; its front-end representation is what /api/world-data
# serves until a round ends and the data is reloaded from the game folder.
world_graph = WorldGraph('original_setup/contexts/world_map.edgelist')
dico_world = world_graph.push_data_to_front()

# --- FastAPI application ----------------------------------------------------
app = FastAPI()

# CORS: every origin is currently allowed (this is NOT restricted to the
# React app's address). NOTE(review): combined with allow_credentials=True this
# is very permissive; restrict allow_origins to the frontend URL(s) before any
# public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body for POST /api/generate-text. Comments (not a docstring) are
# used here so the generated OpenAPI schema stays unchanged.
class Message(BaseModel):
    message: str  # player's text; "" signals the start of a new round in send_message
@app.get("/chat-history", tags=['History'])
async def get_chat_history():
    """Return the chat history of the current game.

    On the first request handled by this process, ``initialize_game()`` is
    called (presumably creating a new game folder) and its returned number is
    used. Afterwards the current game is taken to be
    ``len(os.listdir('games/'))``.

    Returns:
        dict: ``{"chat_history": ...}`` as loaded from ``games/game_<n>``.

    Raises:
        HTTPException: status 500 carrying the error message if anything fails.
    """
    global init_game
    try:
        # If we're at the beginning of a game, start it first.
        if init_game:
            game_number = initialize_game()
            init_game = False
        else:
            # NOTE(review): assumes games/ holds exactly one entry per game,
            # numbered consecutively from 1 — stray files would skew this.
            game_number = len(os.listdir('games/'))
        chat_history = load_chat_history(f'games/game_{game_number}')
        return {"chat_history": chat_history}
    except Exception as e:
        # Preserve the original exception as the cause for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
def _load_or_create_round_context(game_number, new_round):
    """Return ``(idea, concern, advisor_full, events)`` for the current round.

    If ``new_round`` is true, a fresh context is generated with
    ``generate_round_context``. The context is written to
    ``round_context.json`` and the consequences to
    ``round_consequences.json`` in the game folder. Otherwise the previously
    written ``round_context.json`` is read back.

    Raises:
        FileNotFoundError: if the round is ongoing but its context file is missing.
    """
    game_dir = f'games/game_{game_number}'
    file_path = f'{game_dir}/round_context.json'
    if new_round:
        idea, concern, advisor_full, events, consequences = generate_round_context(game_number)
        round_context = {
            "idea": idea,
            "concern": concern,
            'advisor': advisor_full,
            "events": events
        }
        with open(file_path, 'w') as f:
            json.dump(round_context, f, indent=4)
        with open(f'{game_dir}/round_consequences.json', 'w') as f:
            json.dump(consequences, f, indent=4)
        return idea, concern, advisor_full, events

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Round context file not found: {file_path}")
    with open(file_path, 'r') as f:
        round_context = json.load(f)
    return (
        round_context.get("idea"),
        round_context.get("concern"),
        round_context.get("advisor"),
        round_context.get("events"),
    )


def _build_chat_messages(system_prompt, chat_history):
    """Convert the stored chat history into a chat-completion ``messages`` list.

    The system prompt comes first. Each stored turn then contributes a
    ``user`` message (``"..."`` when the turn has no user entry), followed by
    an ``assistant`` message when Trump replied in that turn.
    """
    # Shape inferred from usage: a list of dicts whose values are turns like
    # {'user': {'message': ...}, 'trump': {'message': ...}} — confirm against
    # update_chat_history.
    messages = [{"role": "system", "content": system_prompt}]
    for interaction in chat_history:
        for turn in interaction.values():
            user_message = turn['user']['message'] if 'user' in turn else "..."
            messages.append({"role": "user", "content": user_message})
            # .get(): a turn without a 'trump' entry simply has no reply yet,
            # mirroring the guarded 'user' lookup above.
            trump_entry = turn.get('trump')
            if trump_entry and trump_entry['message']:
                messages.append({"role": "assistant", "content": trump_entry['message']})
    return messages


@app.post("/api/generate-text")
async def send_message(message: Message):
    """Advance the conversation with the Trump character by one exchange.

    An empty ``message.message`` starts a new round. A fresh round context is
    generated and persisted, and the history sent to the model starts empty.
    Otherwise the stored round context is reloaded and the player's message
    is appended to the history.

    The model's reply is appended to the history, which is then saved. If
    ``check_end`` reports that the round ended, ``process_ending`` runs and
    the world data served by /api/world-data is reloaded from the game folder.

    Returns:
        dict with keys ``character_response``, ``chat_history``,
        ``chat_ended``, ``idea`` and ``idea_is_accepted``.

    Raises:
        HTTPException: status 500 carrying the error message if anything fails.
    """
    global init_game
    global dico_world
    try:
        # If we're at the beginning of a game, start it first.
        if init_game:
            game_number = initialize_game()
            init_game = False
        else:
            game_number = len(os.listdir('games/'))
        game_dir = f'games/game_{game_number}'

        # Load existing chat history
        chat_history = load_chat_history(game_dir)

        # An empty message marks the beginning of a round.
        new_round = message.message == ""
        idea, concern, advisor_full, events = _load_or_create_round_context(
            game_number, new_round
        )
        if new_round:
            chat_history = []
        else:
            chat_history = update_chat_history(chat_history, user_message=message.message)

        formatted_prompt = instruction_prompt.format(
            hints=hints,
            chat_history=chat_history,  # kept for the template; history is also sent as messages below
            character=trump_character,
            rules=game_rules,
            triggers=triggers,
            advisor=advisor_full,
            events=events,
            idea=idea,
            concern=concern,
        )

        messages = _build_chat_messages(formatted_prompt, chat_history)
        # TODO: this synchronous client call blocks the event loop inside an
        # async handler; switch to the SDK's async API (also needed to stream).
        chat_response = client.chat.complete(model=model, messages=messages)
        trump_response = chat_response.choices[0].message.content

        chat_history = update_chat_history(chat_history, trump_message=trump_response)
        save_chat_history(chat_history, game_dir)

        is_ending, idea_is_accepted = check_end(trump_response)
        if is_ending:
            process_ending(idea_is_accepted, game_number, idea)
            # Refresh the world data served by /api/world-data.
            dico_world = WorldGraph(f'{game_dir}/world_graph.edgelist').push_data_to_front()

        return {
            "character_response": trump_response,
            "chat_history": chat_history,
            "chat_ended": is_ending,
            "idea": idea,
            "idea_is_accepted": idea_is_accepted,
        }
    except Exception as e:
        # Keep the full traceback in the server log, then answer with a 500
        # whose JSON body carries the message (same contract as the other
        # endpoints) instead of letting the raw exception escape the app.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/world-data", tags=['Game'])
async def get_world_data():
    """Return the current world-map data for the frontend.

    ``dico_world`` is built at import time from the initial world map and is
    replaced by ``send_message`` whenever a round ends.

    Returns:
        dict: ``{"world_data": dico_world}``.
    """
    # Reading a module global needs no ``global`` statement, and returning it
    # cannot fail, so no error handling is required here.
    return {"world_data": dico_world}
# Serve images under /images. This mount must be registered before the
# catch-all "/" mount below, which would otherwise match these paths first.
app.mount("/images", StaticFiles(directory="static/images"), name="images")
# Catch-all mount for the frontend's static files (html=True enables
# StaticFiles' HTML mode). Keep this as the last mount: it matches every path,
# so anything registered after it would be unreachable.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|