File size: 5,867 Bytes
5830d86
 
 
4cab1f1
5830d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62392ef
4cab1f1
5830d86
4cab1f1
 
 
 
 
 
5830d86
 
 
4ce1ecd
5830d86
 
 
 
 
 
 
 
 
 
4cab1f1
 
2bc5010
5830d86
4cab1f1
 
 
5830d86
 
 
 
 
 
4cab1f1
5830d86
 
 
 
 
4cab1f1
5830d86
 
4cab1f1
 
 
 
 
 
 
 
 
 
 
 
 
 
5830d86
4cab1f1
 
 
 
5830d86
 
 
 
 
4cab1f1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import requests
import asyncio
import aiohttp
from typing import Optional
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.schema import SystemMessage, HumanMessage

class MistralClient:
    """Async wrapper around two LangChain ``ChatMistralAI`` chat models.

    ``chat_model`` (temperature 0.7) serves the generation calls below;
    ``fixing_model`` (temperature 0.1) is a lower-temperature instance kept
    for the fixing parser. Calls made through this class are spaced at least
    ``min_delay`` seconds apart.
    """

    # Fine-tuned model identifier used by default for both chat instances.
    DEFAULT_MODEL = "ft:ministral-3b-latest:82f3f89c:20250125:12222969"

    def __init__(self, api_key: str, model: str = DEFAULT_MODEL):
        """Create the chat models.

        Args:
            api_key: Mistral API key.
            model: Model identifier shared by both chat instances.
        """
        self.chat_model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model,
            temperature=0.7,
        )

        # Low-temperature model for the fixing parser.
        self.fixing_model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model,
            temperature=0.1,
        )

        # Rate limiting: event-loop time at which the latest call was (or is
        # scheduled to be) released, and the minimum spacing between calls.
        self.last_call_time = 0
        self.min_delay = 1  # seconds

    async def _wait_for_rate_limit(self):
        """Sleep until at least ``min_delay`` seconds after the previous call.

        The release slot is reserved *before* sleeping. There is no ``await``
        between reading and updating ``last_call_time``, so concurrent callers
        are queued ``min_delay`` apart instead of all waking at the same time.
        """
        now = asyncio.get_running_loop().time()
        release_at = max(now, self.last_call_time + self.min_delay)
        self.last_call_time = release_at
        delay = release_at - now
        if delay > 0:
            await asyncio.sleep(delay)

    async def generate_story(self, messages) -> str:
        """Return the model's reply content for a list of chat messages.

        Any error is printed and then re-raised.
        """
        try:
            await self._wait_for_rate_limit()
            # ainvoke keeps the event loop free during the HTTP request; the
            # synchronous invoke() would block every other coroutine.
            response = await self.chat_model.ainvoke(messages)
            return response.content
        except Exception as e:
            print(f"Error in Mistral API call: {str(e)}")
            raise

    async def transform_prompt(self, story_text: str, system_prompt: str) -> str:
        """Turn a story passage into a short artistic prompt.

        Falls back to returning ``story_text`` unchanged if anything fails.
        """
        try:
            await self._wait_for_rate_limit()
            messages = [
                SystemMessage(content=system_prompt),
                HumanMessage(content=f"Transform into a short prompt: {story_text}")
            ]
            response = await self.chat_model.ainvoke(messages)
            return response.content
        except Exception as e:
            print(f"Error transforming prompt: {str(e)}")
            return story_text

class FluxClient:
    """Async client for the image-generation endpoint named by ``FLUX_ENDPOINT``.

    A single HTTP session is created lazily and reused across calls; call
    :meth:`close` to release it (a later call opens a fresh one).
    """

    # Style text prepended to every prompt.
    STYLE_PREFIX = "François Schuiten comic book artist."
    # Sent as ``negative_prompt`` in the request parameters.
    NEGATIVE_PROMPT = "Bubbles, text, caption. Do not include bright or clean clothing."

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.endpoint = os.getenv("FLUX_ENDPOINT")
        self._session = None

    async def _get_session(self):
        """Return the shared session, (re)creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    def _build_payload(self, prompt, width, height, num_inference_steps, guidance_scale) -> dict:
        """Build the JSON request body.

        Width and height are rounded down to a multiple of 8, with a minimum
        of 8 so a tiny size never becomes 0.
        """
        width = max(8, (width // 8) * 8)
        height = max(8, (height // 8) * 8)
        return {
            "inputs": "in the style of " + self.STYLE_PREFIX + " --- content: " + prompt,
            "parameters": {
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "width": width,
                "height": height,
                "negative_prompt": self.NEGATIVE_PROMPT,
            },
        }

    async def generate_image(self,
                             prompt: str,
                             width: int,
                             height: int,
                             num_inference_steps: int = 3,
                             guidance_scale: float = 9.0) -> Optional[bytes]:
        """Generate an image from *prompt*.

        Returns:
            The raw response body (requested with ``Accept: image/jpeg``) on
            HTTP 200, otherwise ``None``. Errors are printed, never raised.
        """
        if not self.endpoint:
            print("Error in FluxClient.generate_image: FLUX_ENDPOINT is not set")
            return None
        try:
            payload = self._build_payload(
                prompt, width, height, num_inference_steps, guidance_scale
            )

            print(f"Sending request to Hugging Face API: {self.endpoint}")
            print(f"Headers: Authorization: Bearer {self.api_key[:4]}...")
            print(f"Request body: {prompt[:100]}...")

            session = await self._get_session()
            async with session.post(
                self.endpoint,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Accept": "image/jpeg",
                },
                json=payload,
            ) as response:
                print(f"Response status code: {response.status}")
                print(f"Response headers: {response.headers}")
                print(f"Response content type: {response.headers.get('content-type', 'unknown')}")

                if response.status != 200:
                    error_content = await response.text()
                    print(f"Error from Flux API: {response.status}")
                    print(f"Response content: {error_content}")
                    return None

                content = await response.read()
                print(f"Received successful response with content length: {len(content)}")
                return content

        except Exception as e:
            print(f"Error in FluxClient.generate_image: {str(e)}")
            import traceback
            print(f"Traceback: {traceback.format_exc()}")
            return None

    async def close(self):
        """Close the shared session, if any, and forget it so it can be recreated."""
        if self._session is not None:
            await self._session.close()
            self._session = None