File size: 19,316 Bytes
5830d86
fcaa899
24bdd7c
5830d86
 
fcaa899
 
287d153
 
a461215
5830d86
 
 
 
f370e63
5830d86
 
 
 
 
f370e63
 
5830d86
f370e63
 
5830d86
fcaa899
 
 
 
24bdd7c
 
 
f370e63
287d153
 
62392ef
24bdd7c
5db3b83
fcaa899
 
 
 
24bdd7c
 
 
 
 
8e4efa5
24bdd7c
fcaa899
 
 
 
 
24bdd7c
 
0146535
 
 
 
 
 
 
5830d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcaa899
5db3b83
 
 
 
 
5830d86
5db3b83
5830d86
 
 
 
4cab1f1
5db3b83
77af7b8
 
5db3b83
fcaa899
287d153
 
5830d86
 
287d153
 
 
 
 
 
62392ef
 
 
 
4cab1f1
 
 
 
 
5830d86
 
 
 
 
 
 
 
 
 
 
 
 
24bdd7c
 
 
 
 
 
 
 
 
 
 
 
77af7b8
fcaa899
287d153
 
24bdd7c
5db3b83
4cab1f1
24bdd7c
5db3b83
4cab1f1
5db3b83
24bdd7c
5db3b83
287d153
4cab1f1
287d153
24bdd7c
4cab1f1
 
24bdd7c
 
4cab1f1
287d153
5db3b83
24bdd7c
a461215
 
4cab1f1
a461215
 
 
 
 
4cab1f1
 
 
 
5db3b83
4cab1f1
 
 
a461215
 
 
 
 
4cab1f1
 
a461215
 
 
 
 
4cab1f1
 
 
 
 
 
 
 
a461215
4cab1f1
 
5db3b83
4cab1f1
5db3b83
 
4cab1f1
287d153
4cab1f1
5db3b83
5830d86
4cab1f1
 
 
 
5db3b83
4cab1f1
 
 
 
 
 
287d153
 
5db3b83
fcaa899
287d153
 
 
fcaa899
 
5830d86
287d153
 
5830d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287d153
5830d86
 
 
 
 
287d153
5830d86
 
287d153
5830d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cab1f1
5830d86
 
 
 
 
 
 
 
 
 
 
 
287d153
 
5830d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62392ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cab1f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5830d86
 
 
 
 
 
 
 
 
 
 
287d153
24bdd7c
f370e63
24bdd7c
fcaa899
 
f370e63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
import os
from dotenv import load_dotenv
import base64
import time
import random
import asyncio
import aiohttp
from lorem.text import TextLorem
from contextlib import asynccontextmanager


# Placeholder-text generator used only by the /api/test/chat endpoint.
# Words are drawn from the tiny vocabulary "A B C D" and joined with "-" (wsep).
# NOTE(review): srange=(2, 3) appears to bound the number of words per sentence — confirm against the `lorem` package docs.
lorem = TextLorem(wsep='-', srange=(2,3), words="A B C D".split())


# Import local modules
if os.getenv("DOCKER_ENV"):
    from server.game.game_logic import GameState, StoryGenerator, MAX_RADIATION
    from server.api_clients import FluxClient
else:
    from game.game_logic import GameState, StoryGenerator, MAX_RADIATION
    from api_clients import FluxClient

# Load environment variables from a .env file, if one is present.
# NOTE(review): this runs after the DOCKER_ENV check and the local-module imports above,
# so DOCKER_ENV (and anything those modules read at import time) cannot come from .env —
# confirm that is intended.
load_dotenv()

# API configuration
API_HOST = os.getenv("API_HOST", "0.0.0.0")
API_PORT = int(os.getenv("API_PORT", "8000"))
# Directory of the built client; a relative path is resolved against the process working directory.
STATIC_FILES_DIR = os.getenv("STATIC_FILES_DIR", "../client/dist")
HF_API_KEY = os.getenv("HF_API_KEY")
# Default token for development.
# NOTE(review): this is a hard-coded credential default, and AWS_TOKEN is not referenced by the
# code visible in this file — confirm it is still needed.
AWS_TOKEN = os.getenv("AWS_TOKEN", "VHVlIEZlYiAyNyAwOTowNzoyMiBDRVQgMjAyNA==")
ELEVEN_LABS_API_KEY = os.getenv("ELEVEN_LABS_API_KEY")  # API key for ElevenLabs text-to-speech

app = FastAPI(title="Echoes of Influence")

# Configure CORS.
# Starlette compares ``allow_origins`` entries literally (only a bare "*" acts as a
# wildcard), so an entry such as "https://*.hf.space" never matches a real origin.
# Hugging Face Space subdomains are therefore matched with ``allow_origin_regex``,
# anchored so that look-alike hosts (e.g. "x.hf.space.evil.com") are rejected.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",  # Vite dev server
        f"http://localhost:{API_PORT}",  # API port
        "https://huggingface.co",  # HF main domain
        "https://mistral-ai-game-jam-dont-lookup.hf.space",  # Our HF Space URL
    ],
    allow_origin_regex=r"https://[a-z0-9-]+\.hf\.space$",  # HF Spaces domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize game components.
# NOTE(review): this single GameState object is shared by every caller of /api/chat
# and /api/health, so concurrent players overwrite each other's progress.
game_state = GameState()

# Check for API key: the server refuses to start without a Mistral key.
mistral_api_key = os.getenv("MISTRAL_API_KEY")
if not mistral_api_key:
    raise ValueError("MISTRAL_API_KEY environment variable is not set")

story_generator = StoryGenerator(api_key=mistral_api_key)
# HF_API_KEY may be None (os.getenv default); how FluxClient handles that is not visible here.
flux_client = FluxClient(api_key=HF_API_KEY)

# Per-client state used by the /api/test/* endpoints, keyed by the x-client-id header:
# one aiohttp session per client (closed only at shutdown)...
client_sessions: Dict[str, aiohttp.ClientSession] = {}
# ...and, per client, the latest task for each request type ("chat", "image").
client_requests: Dict[str, Dict[str, asyncio.Task]] = {}

async def get_client_session(client_id: str) -> aiohttp.ClientSession:
    """Return the aiohttp session cached for ``client_id``, creating it on first use.

    NOTE(review): ``client_id`` originates from the client-controlled ``x-client-id``
    header and sessions are only closed at shutdown, so this cache grows with every
    distinct id seen.
    """
    cached = client_sessions.get(client_id)
    if cached is None:
        cached = aiohttp.ClientSession()
        client_sessions[client_id] = cached
    return cached

async def cancel_previous_request(client_id: str, request_type: str):
    """Cancel the client's in-flight task of ``request_type`` and wait for it to stop.

    Nothing happens if no task is stored for that client/type or if the stored task
    has already finished. The ``CancelledError`` raised by awaiting the cancelled
    task is swallowed.
    """
    pending = client_requests.get(client_id, {}).get(request_type)
    if pending is None or pending.done():
        return

    pending.cancel()
    try:
        await pending
    except asyncio.CancelledError:
        pass

async def store_request(client_id: str, request_type: str, task: asyncio.Task):
    """Record ``task`` as the client's current request of ``request_type``.

    Any task previously stored for the same client and type is replaced (not cancelled).
    """
    client_requests.setdefault(client_id, {})[request_type] = task

class Choice(BaseModel):
    """One selectable option offered to the player."""
    id: int  # 1-based position in the choice list, as assigned by the chat endpoints
    text: str

class StoryResponse(BaseModel):
    """Story segment returned by /api/chat and /api/test/chat."""
    story_text: str = Field(description="The story text with proper nouns in bold using ** markdown")
    choices: List[Choice]  # empty once the story has ended (death or victory)
    # NOTE(review): /api/chat adds the LLM's radiation increase without clamping, so the
    # "0 to 10" range in this description is not enforced here.
    radiation_level: int = Field(description="Current radiation level from 0 to 10")
    is_victory: bool = Field(description="Whether this segment ends in Sarah's victory", default=False)
    is_first_step: bool = Field(description="Whether this is the first step of the story", default=False)
    is_last_step: bool = Field(description="Whether this is the last step (victory or death)", default=False)
    # Required field (no default). min_items/max_items is the pydantic v1 spelling (v2 calls
    # them min_length/max_length) — NOTE(review): confirm against the installed pydantic version.
    image_prompts: List[str] = Field(description="List of 1 to 3 comic panel descriptions that illustrate the key moments of the scene", min_items=1, max_items=3)

class ChatMessage(BaseModel):
    """Player input for the chat endpoints."""
    message: str  # the literal "restart" resets the game (compared case-insensitively by /api/chat)
    choice_id: Optional[int] = None  # id of the Choice the player picked, if any

class ImageGenerationRequest(BaseModel):
    """Request body for /api/generate-image and /api/test/generate-image.

    For /api/generate-image the prompt is story text that is rewritten into an art
    prompt before rendering.
    """
    prompt: str
    width: int = Field(description="Width of the image to generate")
    height: int = Field(description="Height of the image to generate")

class ImageGenerationResponse(BaseModel):
    """Success/error envelope for image generation.

    NOTE(review): no endpoint in this file declares this as its response_model; the
    image endpoints return plain dicts using the same keys.
    """
    success: bool
    image_base64: Optional[str] = None  # set on success
    error: Optional[str] = None  # set on failure

class TextToSpeechRequest(BaseModel):
    """Request body for /api/text-to-speech."""
    text: str  # may contain ** markdown, which the endpoint strips before synthesis
    # Default ElevenLabs voice ID.
    # NOTE(review): this was labelled "Rachel", but ElevenLabs documents Rachel as
    # 21m00Tcm4TlvDq8ikWAM — confirm which voice this ID actually selects.
    voice_id: str = "nPczCjzI2devNBz1zQrb"

class DirectImageGenerationRequest(BaseModel):
    """Request body for /api/generate-image-direct (the prompt is used verbatim)."""
    prompt: str = Field(description="The prompt to use directly for image generation")
    width: int = Field(description="Width of the image to generate")
    height: int = Field(description="Height of the image to generate")

async def get_test_image(client_id: str, width=1024, height=1024):
    """Download a random grayscale, blurred placeholder image from Lorem Picsum.

    The request goes through the per-client aiohttp session.

    Returns:
        The image bytes, base64-encoded as a ``str``.

    Raises:
        Exception: if Lorem Picsum answers with a non-200 status.
    """
    picsum_url = f"https://picsum.photos/{width}/{height}?grayscale&blur=2"

    http = await get_client_session(client_id)
    async with http.get(picsum_url) as resp:
        if resp.status != 200:
            raise Exception(f"Failed to fetch image: {resp.status}")
        payload = await resp.read()

    return base64.b64encode(payload).decode('utf-8')

@app.get("/api/health")
async def health_check():
    """Report liveness together with a snapshot of the shared game state."""
    snapshot = {
        "story_beat": game_state.story_beat,
        "radiation_level": game_state.radiation_level,
    }
    return {"status": "healthy", "game_state": snapshot}

@app.post("/api/chat", response_model=StoryResponse)
async def chat_endpoint(chat_message: ChatMessage):
    """Advance the shared game by one LLM-generated story segment.

    Sending ``"restart"`` (case-insensitive) resets ``game_state`` first. The
    segment's radiation increase is applied; reaching ``MAX_RADIATION`` ends the
    game in death. From story beat 5 onward, each segment has a growing chance
    ((story_beat - 4) * 20 %) of being replaced by a victory ending. The story
    beat only advances while the game continues.

    Raises:
        HTTPException: 500 wrapping any error raised while building the segment.
    """
    try:
        print("Received chat message:", chat_message)

        # Handle restart
        if chat_message.message.lower() == "restart":
            print("Handling restart - Resetting game state")
            game_state.reset()
            previous_choice = "none"
            print(f"After reset - story_beat: {game_state.story_beat}")
        else:
            previous_choice = f"Choice {chat_message.choice_id}" if chat_message.choice_id else "none"

        print("Previous choice:", previous_choice)
        print("Current story beat:", game_state.story_beat)

        # Generate story segment
        llm_response = await story_generator.generate_story_segment(game_state, previous_choice)
        print("Generated story segment:", llm_response)

        # Update radiation level
        game_state.radiation_level += llm_response.radiation_increase
        print("Updated radiation level:", game_state.radiation_level)

        # Check for radiation death
        is_death = game_state.radiation_level >= MAX_RADIATION
        if is_death:
            llm_response.story_text += f"""

MORT PAR RADIATION: Le corps de Sarah ne peut plus supporter ce niveau de radiation ({game_state.radiation_level}/10). 
Ses cellules se désagrègent alors qu'elle s'effondre, l'esprit rempli de regrets concernant sa sœur. 
Les fournitures médicales qu'elle transportait n'atteindront jamais leur destination. 
Sa mission s'arrête ici, une autre victime du tueur invisible des terres désolées."""
            llm_response.choices = []
            # On death, keep only a single image prompt
            if len(llm_response.image_prompts) > 1:
                llm_response.image_prompts = [llm_response.image_prompts[0]]

        # Add segment to history (before victory check to include final state)
        game_state.add_to_history(llm_response.story_text, previous_choice, llm_response.image_prompts)

        # Check for victory condition
        if not is_death and game_state.story_beat >= 5:
            # Victory chance grows with the number of steps: +20% per step from beat 5 onward
            victory_chance = (game_state.story_beat - 4) * 0.2
            if random.random() < victory_chance:
                llm_response.is_victory = True
                # Written at column 0 (like the death text above): indenting the literal
                # would put the source indentation into the story text, and the indented
                # "VICTOIRE !" after a blank line would render as a markdown code block.
                llm_response.story_text = f"""Sarah l'a fait ! Elle a trouvé un bunker sécurisé avec des survivants.
À l'intérieur, elle découvre une communauté organisée qui a réussi à maintenir un semblant de civilisation.
Ils ont même un système de décontamination ! Son niveau de radiation : {game_state.radiation_level}/10.
Elle peut enfin se reposer et peut-être un jour, reconstruire un monde meilleur.

VICTOIRE !"""
                llm_response.choices = []
                # On victory, keep only a single image prompt
                if len(llm_response.image_prompts) > 1:
                    llm_response.image_prompts = [llm_response.image_prompts[0]]

        # On the first step, keep only a single image prompt
        if game_state.story_beat == 0 and len(llm_response.image_prompts) > 1:
            llm_response.image_prompts = [llm_response.image_prompts[0]]

        # Convert LLM choices to API choices format (ids start at 1)
        choices = [] if is_death or llm_response.is_victory else [
            Choice(id=i, text=choice.strip())
            for i, choice in enumerate(llm_response.choices, 1)
        ]

        # Convert LLM response to API response format
        response = StoryResponse(
            story_text=llm_response.story_text,
            choices=choices,
            radiation_level=game_state.radiation_level,
            is_victory=llm_response.is_victory,
            is_first_step=game_state.story_beat == 0,
            is_last_step=is_death or llm_response.is_victory,
            image_prompts=llm_response.image_prompts
        )

        # Only increment story beat if not dead and not victory
        if not is_death and not llm_response.is_victory:
            game_state.story_beat += 1
            print("Incremented story beat to:", game_state.story_beat)

        print("Sending response:", response)
        return response

    except Exception as e:
        import traceback
        print(f"Error in chat_endpoint: {str(e)}")
        print("Traceback:", traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/generate-image")
async def generate_image(request: ImageGenerationRequest):
    """Rewrite a story excerpt into an art prompt and render it with Flux.

    Failures are reported in the response body rather than as HTTP errors.

    Returns:
        ``{"success": True, "image_base64": ...}`` on success, otherwise
        ``{"success": False, "error": ...}``.
    """
    import inspect
    import traceback

    try:
        # Transform story into art prompt
        art_prompt = await story_generator.transform_story_to_art_prompt(request.prompt)

        print(f"Generating image with dimensions: {request.width}x{request.height}")
        print(f"Using prompt: {art_prompt}")

        # Generate image using Flux client. /api/generate-image-direct awaits this same
        # call while this endpoint used to call it without awaiting, so one of the two
        # was broken. Accept both a coroutine and a plain return value.
        # NOTE(review): if FluxClient.generate_image is synchronous it blocks the event loop.
        image_bytes = flux_client.generate_image(
            prompt=art_prompt,
            width=request.width,
            height=request.height
        )
        if inspect.isawaitable(image_bytes):
            image_bytes = await image_bytes

        if not image_bytes:
            print("No image bytes received from Flux client")
            return {"success": False, "error": "Failed to generate image"}

        print(f"Received image bytes of length: {len(image_bytes)}")
        # Ensure we're getting raw bytes and encoding them properly
        if isinstance(image_bytes, str):
            print("Warning: image_bytes is a string, converting to bytes")
            image_bytes = image_bytes.encode('utf-8')
        base64_image = base64.b64encode(image_bytes).decode('utf-8').strip('"')
        print(f"Converted to base64 string of length: {len(base64_image)}")
        print(f"First 100 chars of base64: {base64_image[:100]}")
        return {"success": True, "image_base64": base64_image}

    except Exception as e:
        print(f"Error generating image: {str(e)}")
        print(f"Error type: {type(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        return {"success": False, "error": str(e)}

@app.post("/api/test/chat")
async def test_chat_endpoint(request: Request, chat_message: ChatMessage):
    """Test endpoint that returns random placeholder story data (no LLM call).

    A newer chat request carrying the same ``x-client-id`` header cancels this one,
    which is then answered with HTTP 409. Any other failure becomes a 500.
    """
    try:
        client_id = request.headers.get("x-client-id", "default")

        # Cancel any previous chat request from this client
        await cancel_previous_request(client_id, "chat")

        async def generate_chat_response():
            # Random placeholder story text
            story_text = f"**Sarah** {lorem.paragraph()}"

            # Random radiation level, nudged up by the chosen option id and capped at 10
            radiation_level = min(10, random.randint(0, 3) + (chat_message.choice_id or 0))

            # Same restart convention as /api/chat (case-insensitive)
            is_first_step = chat_message.message.lower() == "restart"

            # Same death rule as /api/chat. The former threshold of 30 could never be
            # reached because radiation_level is capped at 10 above.
            is_death = radiation_level >= MAX_RADIATION
            # 10% chance of victory on any step after the first
            is_victory = not is_death and not is_first_step and random.random() < 0.1
            is_last_step = is_death or is_victory

            # Two random choices unless the story is over
            choices = [] if is_last_step else [
                Choice(id=i, text=lorem.sentence()) for i in (1, 2)
            ]

            # StoryResponse requires 1-3 image prompts. Like /api/chat, send a single
            # prompt on the first and last steps.
            num_panels = 1 if is_first_step or is_last_step else random.randint(1, 3)
            image_prompts = [lorem.sentence() for _ in range(num_panels)]

            return StoryResponse(
                story_text=story_text,
                choices=choices,
                radiation_level=radiation_level,
                is_victory=is_victory,
                is_first_step=is_first_step,
                is_last_step=is_last_step,
                image_prompts=image_prompts,
            )

        # Create and store the new request so a later request from this client can cancel it
        task = asyncio.create_task(generate_chat_response())
        await store_request(client_id, "chat", task)

        try:
            return await task
        except asyncio.CancelledError:
            print(f"[INFO] Chat request cancelled for client {client_id}")
            raise HTTPException(status_code=409, detail="Request cancelled")

    except HTTPException:
        # Keep deliberate HTTP errors (the 409 above) instead of re-wrapping them as 500s
        raise
    except Exception as e:
        print(f"[ERROR] Error in test_chat_endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/test/generate-image")
async def test_generate_image(request: Request, image_request: ImageGenerationRequest):
    """Test endpoint that returns a random placeholder image instead of a Flux render.

    A newer image request carrying the same ``x-client-id`` header cancels this one.
    Errors are reported in the body: the result is always either
    ``{"success": True, "image_base64": ...}`` or ``{"success": False, "error": ...}``.
    """
    try:
        client_id = request.headers.get("x-client-id", "default")
        print(f"[DEBUG] Client ID: {client_id}")
        print(f"[DEBUG] Raw request data: {image_request}")

        # Only one image request per client may be in flight
        await cancel_previous_request(client_id, "image")

        fetch = asyncio.create_task(
            get_test_image(client_id, image_request.width, image_request.height)
        )
        await store_request(client_id, "image", fetch)

        try:
            encoded = await fetch
        except asyncio.CancelledError:
            print(f"[INFO] Image request cancelled for client {client_id}")
            return {"success": False, "error": "Request cancelled"}
        return {"success": True, "image_base64": encoded}

    except Exception as e:
        print(f"[ERROR] Detailed error in test_generate_image: {str(e)}")
        return {"success": False, "error": str(e)}

@app.post("/api/text-to-speech")
async def text_to_speech(request: TextToSpeechRequest):
    """Convert text to speech through the ElevenLabs API.

    Markdown bold markers (``**``) are stripped before synthesis.

    Returns:
        ``{"success": True, "audio_base64": ...}``, where the audio was requested
        as ``audio/mpeg``.

    Raises:
        HTTPException: 500 when ELEVEN_LABS_API_KEY is not configured; the upstream
            status code (with ElevenLabs' error text) when the API call fails; 500
            for any other unexpected error.
    """
    from urllib.parse import quote

    try:
        if not ELEVEN_LABS_API_KEY:
            raise HTTPException(status_code=500, detail="ElevenLabs API key not configured")

        # Strip ** markdown markers so they are not read aloud
        clean_text = request.text.replace("**", "")

        # voice_id is client-supplied: percent-encode it (including "/") so it stays a
        # single path segment and cannot steer this request, which carries our API key,
        # to a different ElevenLabs endpoint.
        url = f"https://api.elevenlabs.io/v1/text-to-speech/{quote(request.voice_id, safe='')}"
        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": ELEVEN_LABS_API_KEY
        }
        data = {
            "text": clean_text,
            "model_id": "eleven_multilingual_v2",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.75
            }
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=data, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise HTTPException(status_code=response.status, detail=error_text)
                audio_content = await response.read()

        # Base64-encode the audio so it can travel inside a JSON body
        audio_base64 = base64.b64encode(audio_content).decode('utf-8')
        return {"success": True, "audio_base64": audio_base64}

    except HTTPException as e:
        # Keep the intended status code instead of re-wrapping it as a 500
        print(f"Error in text_to_speech: {str(e)}")
        raise
    except Exception as e:
        print(f"Error in text_to_speech: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/api/generate-image-direct")
async def generate_image_direct(request: DirectImageGenerationRequest):
    """Render ``request.prompt`` with Flux as-is, without the art-prompt rewrite.

    Failures are reported in the response body rather than as HTTP errors.

    Returns:
        ``{"success": True, "image_base64": ...}`` on success, otherwise
        ``{"success": False, "error": ...}``.
    """
    import inspect
    import traceback

    try:
        print(f"Generating image directly with dimensions: {request.width}x{request.height}")
        print(f"Using prompt: {request.prompt}")

        # Generate image using Flux client directly without transforming the prompt.
        # /api/generate-image calls this same method without awaiting it, so accept
        # both a coroutine and a plain return value to keep the endpoints consistent.
        # NOTE(review): if FluxClient.generate_image is synchronous it blocks the event loop.
        image_bytes = flux_client.generate_image(
            prompt=request.prompt,
            width=request.width,
            height=request.height
        )
        if inspect.isawaitable(image_bytes):
            image_bytes = await image_bytes

        if not image_bytes:
            print("No image bytes received from Flux client")
            return {"success": False, "error": "Failed to generate image"}

        print(f"Received image bytes of length: {len(image_bytes)}")
        if isinstance(image_bytes, str):
            print("Warning: image_bytes is a string, converting to bytes")
            image_bytes = image_bytes.encode('utf-8')
        base64_image = base64.b64encode(image_bytes).decode('utf-8').strip('"')
        print(f"Converted to base64 string of length: {len(base64_image)}")
        return {"success": True, "image_base64": base64_image}

    except Exception as e:
        print(f"Error generating image: {str(e)}")
        print(f"Error type: {type(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        return {"success": False, "error": str(e)}

@app.on_event("shutdown")
async def shutdown_event():
    """Cancel in-flight per-client tasks and close every per-client HTTP session.

    Snapshots of the registries are iterated because ``cancel_previous_request``
    awaits. A request handled in the meantime can insert into ``client_requests``,
    which would otherwise raise "dictionary changed size during iteration".
    """
    # Cancel all pending requests
    for client_id, tasks_by_type in list(client_requests.items()):
        for request_type in list(tasks_by_type):
            await cancel_previous_request(client_id, request_type)

    # Close all sessions, then drop them so a closed session is never handed out again
    for session in list(client_sessions.values()):
        await session.close()
    client_sessions.clear()

# Mount static files (this should be after all API routes): a mount at "/" matches
# every path, so any route registered after it would be shadowed. html=True makes
# StaticFiles serve index.html for directory requests.
# NOTE(review): StaticFiles checks by default that STATIC_FILES_DIR exists, so startup
# fails if the client has not been built — confirm that is acceptable in development.
app.mount("/", StaticFiles(directory=STATIC_FILES_DIR, html=True), name="static")

if __name__ == "__main__":
    # Development entry point with auto-reload. The app is given as an import string,
    # which uvicorn requires when reload=True.
    import uvicorn

    uvicorn.run(
        "server.server:app",
        host=API_HOST,
        port=API_PORT,
        reload=True,
    )