Abhaykoul committed
Commit 6f99d02 · verified · 1 Parent(s): 1c7150a

Update app.py

Files changed (1)
  1. app.py +709 -617
app.py CHANGED
@@ -1,634 +1,726 @@
- from fastapi import FastAPI, HTTPException, Query
- from fastapi.responses import JSONResponse, StreamingResponse
- from webscout import WEBS, YTTranscriber, LLM, GoogleS
- from typing import Optional, List, Dict
- from fastapi.encoders import jsonable_encoder
- from bs4 import BeautifulSoup
- import requests
  import aiohttp
  import asyncio
- import threading
- import json
- from huggingface_hub import InferenceClient
- from PIL import Image
- import io
- from easygoogletranslate import EasyGoogleTranslate
- from pydantic import BaseModel
-
-
- app = FastAPI()
-
- # Define Pydantic models for request payloads
- class ChatRequest(BaseModel):
-     q: str
-     model: str = "gpt-4o-mini"
-     history: List[Dict[str, str]] = []
-     proxy: Optional[str] = None
-
- class AIRequest(BaseModel):
-     user: str
-     model: str = "llama3-70b"
-     system: str = "Answer as concisely as possible."
-
- @app.get("/")
- async def root():
-     return {"message": "API documentation can be found at /docs"}
-
- @app.get("/health")
- async def health_check():
-     return {"status": "OK"}
-
- @app.get("/api/search")
- async def search(
-     q: str,
-     max_results: int = 10,
-     timelimit: Optional[str] = None,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     backend: str = "api",
-     proxy: Optional[str] = None
- ):
-     """Perform a text search."""
-     try:
-         with WEBS(proxy=proxy) as webs:
-             results = webs.text(
-                 keywords=q,
-                 region=region,
-                 safesearch=safesearch,
-                 timelimit=timelimit,
-                 backend=backend,
-                 max_results=max_results,
-             )
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during search: {e}")

- @app.get("/api/search_google")
- async def search_google(
-     q: str,
-     max_results: int = 10,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     proxy: Optional[str] = None
- ):
-     """Perform a text search."""
-     try:
-         with GoogleS(proxy=proxy) as webs:
-             results = webs.search(
-                 query=q,
-                 region=region,
-                 safe=safesearch,
-                 max_results=max_results,
-             )
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during search: {e}")
- @app.get("/api/images")
- async def images(
-     q: str,
-     max_results: int = 10,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     timelimit: Optional[str] = None,
-     size: Optional[str] = None,
-     color: Optional[str] = None,
-     type_image: Optional[str] = None,
-     layout: Optional[str] = None,
-     license_image: Optional[str] = None,
-     proxy: Optional[str] = None
- ):
-     """Perform an image search."""
-     try:
-         with WEBS(proxy=proxy) as webs:
-             results = webs.images(
-                 keywords=q,
-                 region=region,
-                 safesearch=safesearch,
-                 timelimit=timelimit,
-                 size=size,
-                 color=color,
-                 type_image=type_image,
-                 layout=layout,
-                 license_image=license_image,
-                 max_results=max_results,
-             )
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during image search: {e}")
-
- @app.get("/api/videos")
- async def videos(
-     q: str,
-     max_results: int = 10,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     timelimit: Optional[str] = None,
-     resolution: Optional[str] = None,
-     duration: Optional[str] = None,
-     license_videos: Optional[str] = None,
-     proxy: Optional[str] = None
- ):
-     """Perform a video search."""
-     try:
-         with WEBS(proxy=proxy) as webs:
-             results = webs.videos(
-                 keywords=q,
-                 region=region,
-                 safesearch=safesearch,
-                 timelimit=timelimit,
-                 resolution=resolution,
-                 duration=duration,
-                 license_videos=license_videos,
-                 max_results=max_results,
-             )
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during video search: {e}")

- @app.get("/api/news")
- async def news(
-     q: str,
-     max_results: int = 10,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     timelimit: Optional[str] = None,
-     proxy: Optional[str] = None
- ):
-     """Perform a news search."""
-     try:
-         with WEBS(proxy=proxy) as webs:
-             results = webs.news(
-                 keywords=q,
-                 region=region,
-                 safesearch=safesearch,
-                 timelimit=timelimit,
-                 max_results=max_results
              )
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during news search: {e}")
-
- @app.get("/api/answers")
- async def answers(q: str, proxy: Optional[str] = None):
-     """Get instant answers for a query."""
-     try:
-         with WEBS(proxy=proxy) as webs:
-             results = webs.answers(keywords=q)
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error getting instant answers: {e}")
-
- @app.get("/api/maps")
- async def maps(
-     q: str,
-     place: Optional[str] = None,
-     street: Optional[str] = None,
-     city: Optional[str] = None,
-     county: Optional[str] = None,
-     state: Optional[str] = None,
-     country: Optional[str] = None,
-     postalcode: Optional[str] = None,
-     latitude: Optional[str] = None,
-     longitude: Optional[str] = None,
-     radius: int = 0,
-     max_results: int = 10,
-     proxy: Optional[str] = None
- ):
-     """Perform a maps search."""
-     try:
-         with WEBS(proxy=proxy) as webs:
-             results = webs.maps(keywords=q, place=place, street=street, city=city, county=county, state=state, country=country, postalcode=postalcode, latitude=latitude, longitude=longitude, radius=radius, max_results=max_results)
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during maps search: {e}")
-
- @app.get("/api/chat")
- async def chat(
-     q: str,
-     model: str = "gpt-4o-mini",
-     proxy: Optional[str] = None
- ):
-     """Interact with a specified large language model."""
-     try:
-         with WEBS(proxy=proxy) as webs:
-             results = webs.chat(keywords=q, model=model)
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error getting chat results: {e}")
-
- @app.post("/api/chat-post")
- async def chat_post(request: ChatRequest):
-     """Interact with a specified large language model with chat history."""
-     try:
-         with WEBS(proxy=request.proxy) as webs:
-             results = webs.chat(keywords=request.q, model=request.model, chat_messages=request.history)
-             return JSONResponse(content=jsonable_encoder(results))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error getting chat results: {e}")
-
- @app.get("/api/llm")
- async def llm_chat(
-     model: str,
-     message: str,
-     system_prompt: str = Query(None, description="Optional custom system prompt")
- ):
-     """Interact with a specified large language model with an optional system prompt."""
-     try:
-         messages = [{"role": "user", "content": message}]
-         if system_prompt:
-             messages.insert(0, {"role": "system", "content": system_prompt})
-
-         llm = LLM(model=model)
-         response = llm.chat(messages=messages)
-         return JSONResponse(content={"response": response})
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during LLM chat: {e}")
-
- @app.post("/api/ai-post")
- async def ai_post(request: AIRequest):
-     """Interact with a specified large language model (using AIRequest model)."""
-     try:
-         llm = LLM(model=request.model)
-         response = llm.chat(messages=[
-             {"role": "system", "content": request.system},
-             {"role": "user", "content": request.user}
-         ])
-         return JSONResponse(content={"response": response})
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during AI request: {e}")
-
- def extract_text_from_webpage(html_content):
-     """Extracts visible text from HTML content using BeautifulSoup."""
-     soup = BeautifulSoup(html_content, "html.parser")
-     # Remove unwanted tags
-     for tag in soup(["script", "style", "header", "footer", "nav"]):
-         tag.extract()
-     # Get the remaining visible text
-     visible_text = soup.get_text(strip=True)
-     return visible_text
-
- async def fetch_and_extract(url, max_chars, proxy: Optional[str] = None):
-     """Fetches a URL and extracts text asynchronously."""
-
-     async with aiohttp.ClientSession() as session:
-         try:
-             async with session.get(url, headers={"User-Agent": "Mozilla/5.0"}, proxy=proxy) as response:
-                 response.raise_for_status()
-                 html_content = await response.text()
-                 visible_text = extract_text_from_webpage(html_content)
-                 if len(visible_text) > max_chars:
-                     visible_text = visible_text[:max_chars] + "..."
-                 return {"link": url, "text": visible_text}
-         except (aiohttp.ClientError, requests.exceptions.RequestException) as e:
-             print(f"Error fetching or processing {url}: {e}")
-             return {"link": url, "text": None}
-
- @app.get("/api/web_extract")
- async def web_extract(
-     url: str,
-     max_chars: int = 12000,  # Adjust based on token limit
-     proxy: Optional[str] = None
- ):
-     """Extracts text from a given URL."""
-     try:
-         result = await fetch_and_extract(url, max_chars, proxy)
-         return {"url": url, "text": result["text"]}
-     except requests.exceptions.RequestException as e:
-         raise HTTPException(status_code=500, detail=f"Error fetching or processing URL: {e}")
-
- @app.get("/api/search-and-extract")
- async def web_search_and_extract(
-     q: str,
-     max_results: int = 3,
-     timelimit: Optional[str] = None,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     backend: str = "html",
-     max_chars: int = 6000,
-     extract_only: bool = True,
-     proxy: Optional[str] = None
- ):
-     """
-     Searches using WEBS, extracts text from the top results, and returns both.
-     """
-     try:
-         with WEBS(proxy=proxy) as webs:
-             # Perform WEBS search
-             search_results = webs.text(keywords=q, region=region, safesearch=safesearch,
-                                        timelimit=timelimit, backend=backend, max_results=max_results)
-
-             # Extract text from each result's link asynchronously
-             tasks = [fetch_and_extract(result['href'], max_chars, proxy) for result in search_results if 'href' in result]
-             extracted_results = await asyncio.gather(*tasks)
-
-             if extract_only:
-                 return JSONResponse(content=jsonable_encoder(extracted_results))
-             else:
-                 return JSONResponse(content=jsonable_encoder({"search_results": search_results, "extracted_results": extracted_results}))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during search and extraction: {e}")

- def extract_text_from_webpage2(html_content):
-     """Extracts visible text from HTML content using BeautifulSoup."""
-     soup = BeautifulSoup(html_content, "html.parser")
-     # Remove unwanted tags
-     for tag in soup(["script", "style", "header", "footer", "nav"]):
-         tag.extract()
-     # Get the remaining visible text
-     visible_text = soup.get_text(strip=True)
-     return visible_text

- def fetch_and_extract2(url, max_chars, proxy: Optional[str] = None):
-     """Fetches a URL and extracts text using threading."""
-     proxies = {'http': proxy, 'https': proxy} if proxy else None
      try:
-         response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, proxies=proxies)
-         response.raise_for_status()
-         html_content = response.text
-         visible_text = extract_text_from_webpage2(html_content)
-         if len(visible_text) > max_chars:
-             visible_text = visible_text[:max_chars] + "..."
-         return {"link": url, "text": visible_text}
-     except (requests.exceptions.RequestException) as e:
-         print(f"Error fetching or processing {url}: {e}")
-         return {"link": url, "text": None}

- @app.get("/api/websearch-and-extract-threading")
- def web_search_and_extract_threading(
-     q: str,
-     max_results: int = 3,
-     timelimit: Optional[str] = None,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     backend: str = "html",
-     max_chars: int = 6000,
-     extract_only: bool = True,
-     proxy: Optional[str] = None
- ):
-     """
-     Searches using WEBS, extracts text from the top results using threading, and returns both.
-     """
      try:
-         with WEBS(proxy=proxy) as webs:
-             # Perform WEBS search
-             search_results = webs.text(keywords=q, region=region, safesearch=safesearch,
-                                        timelimit=timelimit, backend=backend, max_results=max_results)
-
-             # Extract text from each result's link using threading
-             extracted_results = []
-             threads = []
-             for result in search_results:
-                 if 'href' in result:
-                     thread = threading.Thread(target=lambda: extracted_results.append(fetch_and_extract2(result['href'], max_chars, proxy)))
-                     threads.append(thread)
-                     thread.start()
-
-             # Wait for all threads to finish
-             for thread in threads:
-                 thread.join()
-
-             if extract_only:
-                 return JSONResponse(content=jsonable_encoder(extracted_results))
              else:
-                 return JSONResponse(content=jsonable_encoder({"search_results": search_results, "extracted_results": extracted_results}))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during search and extraction: {e}")
-
- @app.get("/api/adv_web_search")
- async def adv_web_search(
-     q: str,
-     model: str = "gpt-4o-mini",  # Use webs.chat by default
-     max_results: int = 5,
-     timelimit: Optional[str] = None,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     backend: str = "html",
-     max_chars: int = 15000,
-     system_prompt: str = "You are an advanced AI chatbot. Provide the best answer to the user based on Google search results.",
-     proxy: Optional[str] = None
- ):
-     """
-     Combines web search, web extraction, and chat model for advanced search.
-     """
-     try:
-         with WEBS(proxy=proxy) as webs:
-             search_results = webs.text(keywords=q, region=region,
-                                        safesearch=safesearch,
-                                        timelimit=timelimit, backend=backend,
-                                        max_results=max_results)
-
-         # 2. Extract text from top search result URLs asynchronously
-         extracted_text = ""
-         tasks = [fetch_and_extract(result['href'], 6000, proxy) for result in search_results if 'href' in result]
-         extracted_results = await asyncio.gather(*tasks)
-         for result in extracted_results:
-             if result['text'] and len(extracted_text) < max_chars:
-                 extracted_text += f"## Content from: {result['link']}\n\n{result['text']}\n\n"
-
-         extracted_text[:max_chars]
-
-
-         # 3. Construct the prompt for the chat model
-         ai_prompt = (
-             f"User Query: {q}\n\n"
-             f"Please provide a detailed and accurate answer to the user's query. Include relevant information extracted from the search results below. Ensure to cite sources by providing links to the original content where applicable. Format your response as follows:\n\n"
-             f"1. **Answer:** Provide a clear and comprehensive answer to the user's query.\n"
-             f"2. **Details:** Include any additional relevant details or explanations.\n"
-             f"3. **Sources:** List the sources of the information with clickable links for further reading.\n\n"
-             f"Search Results:\n{extracted_text}"
-         )
-
-         # 4. Get the chat model's response using webs.chat
-         with WEBS(proxy=proxy) as webs:
-             response = webs.chat(keywords=ai_prompt, model=model)
-
-         # 5. Return the results
-         return JSONResponse(content={"response": response})
-
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during advanced search: {e}")

- @app.post("/api/AI_search_google")
- async def adv_web_search(
-     q: str,
-     model: str = "claude-3-haiku",  # Use webs.chat by default
-     max_results: int = 5,
-     timelimit: Optional[str] = None,
-     safesearch: str = "moderate",
-     region: str = "wt-wt",
-     # backend: str = "html",
-     max_chars: int = 6000,
-     system_prompt: str = "You are an advanced AI chatbot. Provide the best answer to the user based on Google search results.",
-     proxy: Optional[str] = None
- ):
-     """
-     Combines web search, web extraction, and chat model for advanced search.
-     """
-     try:
-         with GoogleS(proxy=proxy) as webs:
-             search_results = webs.search(query=q, region=region,
-                                          safe=safesearch,
-                                          time_period=timelimit,
-                                          max_results=max_results)
-         # 2. Extract text from top search result URLs asynchronously
-         extracted_text = ""
-         tasks = [fetch_and_extract(result['href'], 6000, proxy) for result in search_results if 'href' in result]
-         extracted_results = await asyncio.gather(*tasks)
-         for result in extracted_results:
-             if result['text'] and len(extracted_text) < max_chars:
-                 extracted_text += f"## Content from: {result['link']}\n\n{result['text']}\n\n"
-
-         extracted_text[:max_chars]
-
-
-         # 3. Construct the prompt for the chat model
-         ai_prompt = (
-             f"User Query: **{q}**\n\n"
-             f"**Objective:** Provide a comprehensive and informative response to the user's query based on the extracted content from Google search results. Your answer should be structured in Markdown format for clarity and readability.\n\n"
-             f"**Response Structure:**\n"
-             f"1. **Answer:**\n"
-             f"   - Begin with a clear and concise answer to the user's question.\n\n"
-             f"2. **Key Points:**\n"
-             f"   - Highlight essential details or facts relevant to the query using bullet points.\n\n"
-             f"3. **Contextual Information:**\n"
-             f"   - Provide any necessary background or additional context that enhances understanding, using paragraphs as needed.\n\n"
-             f"4. **Summary of Search Results:**\n"
-             f"   - Summarize key findings from the search results, emphasizing diversity in perspectives if applicable, formatted as a list.\n\n"
-             f"5. **Sources:**\n"
-             f"   - List all sources of information with clickable links for further reading, ensuring proper citation of the extracted content, formatted as follows:\n"
-             f"     - [Source Title](URL)\n\n"
-             f"**Search Results:**\n{extracted_text}\n\n"
-             f"---\n\n"
-             f"*Note: Ensure that all sections are clearly marked and that the response is easy to navigate.*"
          )
-
-         # 4. Get the chat model's response using webs.chat
-         with WEBS(proxy=proxy) as webs:
-             response = webs.chat(keywords=ai_prompt, model=model)
-
-         # 5. Return the results
-         return JSONResponse(content={"answer": response})
-
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during advanced search: {e}")
-
-
- @app.get("/api/website_summarizer")
- async def website_summarizer(url: str, proxy: Optional[str] = None):
-     """Summarizes the content of a given URL using a chat model."""
-     try:
-         # Extract text from the given URL
-         proxies = {'http': proxy, 'https': proxy} if proxy else None
-         response = requests.get(url, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"}, proxies=proxies)
-         response.raise_for_status()
-         visible_text = extract_text_from_webpage(response.text)
-         if len(visible_text) > 7500:  # Adjust max_chars based on your needs
-             visible_text = visible_text[:7500] + "..."
-
-         # Use chat model to summarize the extracted text
-         with WEBS(proxy=proxy) as webs:
-             summary_prompt = f"Summarize this in detail in Paragraph: {visible_text}"
-             summary_result = webs.chat(keywords=summary_prompt, model="gpt-4o-mini")
-
-         # Return the summary result
-         return JSONResponse(content=jsonable_encoder({summary_result}))
-
-     except requests.exceptions.RequestException as e:
-         raise HTTPException(status_code=500, detail=f"Error fetching or processing URL: {e}")
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during summarization: {e}")
-
- @app.get("/api/ask_website")
- async def ask_website(url: str, question: str, model: str = "llama-3-70b", proxy: Optional[str] = None):
-     """
-     Asks a question about the content of a given website.
-     """
-     try:
-         # Extract text from the given URL
-         proxies = {'http': proxy, 'https': proxy} if proxy else None
-         response = requests.get(url, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"}, proxies=proxies)
-         response.raise_for_status()
-         visible_text = extract_text_from_webpage(response.text)
-         if len(visible_text) > 7500:  # Adjust max_chars based on your needs
-             visible_text = visible_text[:7500] + "..."
-
-         # Construct a prompt for the chat model
-         prompt = f"Based on the following text, answer this question in Paragraph: [QUESTION] {question} [TEXT] {visible_text}"
-
-         # Use chat model to get the answer
-         with WEBS(proxy=proxy) as webs:
-             answer_result = webs.chat(keywords=prompt, model=model)
-
-         # Return the answer result
-         return JSONResponse(content=jsonable_encoder({answer_result}))
-
-     except requests.exceptions.RequestException as e:
-         raise HTTPException(status_code=500, detail=f"Error fetching or processing URL: {e}")
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during question answering: {e}")
-
-
-
- @app.get("/api/translate")
- async def translate(
-     q: str,
-     from_: Optional[str] = None,
-     to: str = "en",
-     proxy: Optional[str] = None
- ):
-     """Translate text."""
-     try:
-         with WEBS(proxy=proxy) as webs:
-             results = webs.translate(keywords=q, from_=from_, to=to)
-             return JSONResponse(content=jsonable_encoder(results))
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during translation: {e}")
-
- @app.get("/api/google_translate")
- def google_translate(q: str, from_: Optional[str] = 'auto', to: str = "en"):
-     try:
-         translator = EasyGoogleTranslate(
-             source_language=from_,
-             target_language=to,
-             timeout=10
          )
-         result = translator.translate(q)
-         return JSONResponse(content=jsonable_encoder({"detected_language": from_, "original": q, "translated": result}))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error during translation: {e}")
-
- @app.get("/api/youtube/transcript")
- async def youtube_transcript(
-     video_url: str,
-     preserve_formatting: bool = False,
-     proxy: Optional[str] = None  # Add proxy parameter
- ):
-     """Get the transcript of a YouTube video."""
-     try:
-         proxies = {"http": proxy, "https": proxy} if proxy else None
-         transcript = YTTranscriber.get_transcript(video_url, languages=None, preserve_formatting=preserve_formatting, proxies=proxies)
-         return JSONResponse(content=jsonable_encoder(transcript))
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error getting YouTube transcript: {e}")
-
- @app.get("/weather/json/{location}")
- def get_weather_json(location: str):
-     url = f"https://wttr.in/{location}?format=j1"
-     response = requests.get(url)
-     if response.status_code == 200:
-         return response.json()
-     else:
-         return {"error": f"Unable to fetch weather data. Status code: {response.status_code}"}
-
- @app.get("/weather/ascii/{location}")
- def get_ascii_weather(location: str):
-     url = f"https://wttr.in/{location}"
-     response = requests.get(url, headers={'User-Agent': 'curl'})
-     if response.status_code == 200:
-         return response.text
-     else:
-         return {"error": f"Unable to fetch weather data. Status code: {response.status_code}"}

  if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8083)

+ import os
+ import io
+ import uuid
+ import base64
+ from typing import Dict, List, Optional, Any, Union
+ from pathlib import Path
  import aiohttp
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Header, Request
+ from fastapi.responses import StreamingResponse, JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
  import asyncio
+ import uvicorn
+ from datetime import datetime
+ import time
+
+ # Import all TTI providers
+ from webscout.Provider.TTI import (
+     # Import all image providers
+     BlackboxAIImager, AsyncBlackboxAIImager,
+     DeepInfraImager, AsyncDeepInfraImager,
+     AiForceimager, AsyncAiForceimager,
+     NexraImager, AsyncNexraImager,
+     FreeAIImager, AsyncFreeAIImager,
+     NinjaImager, AsyncNinjaImager,
+     TalkaiImager, AsyncTalkaiImager,
+     PiclumenImager, AsyncPiclumenImager,
+     ArtbitImager, AsyncArtbitImager,
+     HFimager, AsyncHFimager,
+ )
+
+ try:
+     from webscout.Provider.TTI import AIArtaImager, AsyncAIArtaImager
+     AIARTA_AVAILABLE = True
+ except ImportError:
+     AIARTA_AVAILABLE = False
+
+ # Create FastAPI instance
+ app = FastAPI(
+     title="WebScout TTI API Server",
+     description="API server for Text-to-Image generation using various providers with OpenAI-compatible interface",
+     version="1.0.0",
+ )
+
+ # Add CORS middleware to allow cross-origin requests
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Storage for generated images (in-memory for demo purposes)
+ # In a production environment, you might want to store these in a database or a file system
+ IMAGE_STORAGE = {}
+
+ # Simple API key verification (demo purposes only)
+ # In production, you'd want a more secure authentication system
+ API_KEYS = {"sk-demo-key": "demo"}
+
+ # Provider mapping
+ PROVIDER_MAP = {
+     "blackbox": {
+         "class": AsyncBlackboxAIImager,
+         "description": "High-performance image generation with advanced retry mechanisms"
+     },
+     "deepinfra": {
+         "class": AsyncDeepInfraImager,
+         "description": "Powerful image generation using FLUX-1-schnell and other models"
+     },
+     "aiforce": {
+         "class": AsyncAiForceimager,
+         "description": "Advanced AI image generation with 12 specialized models"
+     },
+     "nexra": {
+         "class": AsyncNexraImager,
+         "description": "Next-gen image creation with 19+ models"
+     },
+     "freeai": {
+         "class": AsyncFreeAIImager,
+         "description": "Premium image generation with DALL-E 3 and Flux series models"
+     },
+     "ninja": {
+         "class": AsyncNinjaImager,
+         "description": "Ninja-fast image generation with cyberpunk-themed logging"
+     },
+     "talkai": {
+         "class": AsyncTalkaiImager,
+         "description": "Fast and reliable image generation with comprehensive error handling"
+     },
+     "piclumen": {
+         "class": AsyncPiclumenImager,
+         "description": "Professional photorealistic image generation with advanced processing"
+     },
+     "artbit": {
+         "class": AsyncArtbitImager,
+         "description": "Bit-perfect AI art creation with precise control over parameters"
+     },
+     "huggingface": {
+         "class": AsyncHFimager,
+         "description": "Direct integration with HuggingFace's powerful models"
+     },
+ }
+
+ # Add AIArta provider if available
+ if AIARTA_AVAILABLE:
+     PROVIDER_MAP["aiarta"] = {
+         "class": AsyncAIArtaImager,
+         "description": "Generate stunning AI art with AI Arta with 45+ artistic styles"
+     }
+
+ # Provider model info
+ PROVIDER_MODEL_INFO = {
+     "blackbox": {
+         "default": "blackbox-default",
+         "models": ["blackbox-default"],
+         "default_params": {}
+     },
+     "deepinfra": {
+         "default": "flux-1-schnell",
+         "models": ["flux-1-schnell"],
+         "default_params": {
+             "num_inference_steps": 25,
+             "guidance_scale": 7.5,
+             "width": 1024,
+             "height": 1024
+         }
+     },
+     "aiforce": {
+         "default": "flux-1-pro",
+         "models": [
+             "stable-diffusion-xl-lightning",
+             "stable-diffusion-xl-base",
+             "flux-1-pro",
+             "ideogram",
+             "flux",
+             "flux-realism",
+             "flux-anime",
+             "flux-3d",
+             "flux-disney",
+             "flux-pixel",
+             "flux-4o",
+             "any-dark"
+         ],
+         "default_params": {
+             "width": 768,
+             "height": 768
+         }
+     },
+     "nexra": {
+         "default": "midjourney",
+         "models": [
+             "emi",
+             "stablediffusion-1-5",
+             "stablediffusion-2-1",
+             "sdxl-lora",
+             "dalle",
+             "dalle2",
+             "dalle-mini",
+             "flux",
+             "midjourney",
+             "dreamshaper-xl",
+             "dynavision-xl",
+             "juggernaut-xl",
+             "realism-engine-sdxl",
+             "sd-xl-base-1-0",
+             "animagine-xl-v3",
+             "sd-xl-base-inpainting",
+             "turbovision-xl",
+             "devlish-photorealism-sdxl",
+             "realvis-xl-v4"
+         ],
+         "default_params": {}
+     },
+     "freeai": {
+         "default": "dall-e-3",
+         "models": [
+             "dall-e-3",
+             "flux-pro-ultra",
+             "flux-pro",
+             "flux-pro-ultra-raw",
+             "flux-schnell",
+             "flux-realism",
+             "grok-2-aurora"
+         ],
+         "default_params": {
+             "size": "1024x1024",
+             "quality": "standard",
+             "style": "vivid"
+         }
+     },
+     "ninja": {
+         "default": "flux-dev",
+         "models": ["stable-diffusion", "flux-dev"],
+         "default_params": {}
+     },
+     "talkai": {
+         "default": "talkai-default",
+         "models": ["talkai-default"],
+         "default_params": {}
+     },
+     "piclumen": {
+         "default": "piclumen-default",
+         "models": ["piclumen-default"],
+         "default_params": {}
+     },
+     "artbit": {
+         "default": "sdxl",
+         "models": ["sdxl", "sd"],
+         "default_params": {
+             "selected_ratio": "1024"
+         }
+     },
+     "huggingface": {
+         "default": "stable-diffusion-xl-base-1-0",
+         "models": ["stable-diffusion-xl-base-1-0", "stable-diffusion-v1-5"],
+         "default_params": {
+             "guidance_scale": 7.5,
+             "num_inference_steps": 30
+         }
+     }
+ }
+
+ # Normalize model names to OpenAI-like format
+ for provider, info in PROVIDER_MODEL_INFO.items():
+     info["models"] = [model.replace("/", "-").replace(".", "-").replace("_", "-").lower() for model in info["models"]]
+     info["default"] = info["default"].replace("/", "-").replace(".", "-").replace("_", "-").lower()
+
+ # Add AIArta model info if available
+ if AIARTA_AVAILABLE:
+     PROVIDER_MODEL_INFO["aiarta"] = {
+         "default": "flux",
+         "models": [
+             "flux", "medieval", "vincent-van-gogh", "f-dev", "low-poly",
+             "dreamshaper-xl", "anima-pencil-xl", "biomech", "trash-polka",
+             "no-style", "cheyenne-xl", "chicano", "embroidery-tattoo",
+             "red-and-black", "fantasy-art", "watercolor", "dotwork",
+             "old-school-colored", "realistic-tattoo", "japanese-2",
+             "realistic-stock-xl", "f-pro", "revanimated", "katayama-mix-xl",
+             "sdxl-l", "cor-epica-xl", "anime-tattoo", "new-school",
+             "death-metal", "old-school", "juggernaut-xl", "photographic",
+             "sdxl-1-0", "graffiti", "mini-tattoo", "surrealism",
+             "neo-traditional", "on-limbs-black", "yamers-realistic-xl",
+             "pony-xl", "playground-xl", "anything-xl", "flame-design",
+             "kawaii", "cinematic-art", "professional", "flux-black-ink"
+         ],
+         "default_params": {
+             "negative_prompt": "blurry, deformed hands, ugly",
+             "guidance_scale": 7,
+             "num_inference_steps": 30,
+             "aspect_ratio": "1:1"
+         }
+     }
+
+ # Define Pydantic models for request and response validation (OpenAI-compatible)
+
+ class ImageSize(BaseModel):
+     width: int = Field(1024, description="Image width")
+     height: int = Field(1024, description="Image height")
+
+ class ImageGenerationRequest(BaseModel):
+     model: str = Field(..., description="The model to use for image generation")
+     prompt: str = Field(..., description="The prompt to generate images from")
+     n: Optional[int] = Field(1, description="Number of images to generate", ge=1, le=10)
+     size: Optional[str] = Field("1024x1024", description="Image size in format WIDTHxHEIGHT")
+     response_format: Optional[str] = Field("url", description="The format in which the generated images are returned", enum=["url", "b64_json"])
+     user: Optional[str] = Field(None, description="A unique identifier for the user")
+     style: Optional[str] = Field(None, description="Style for the generation")
+     quality: Optional[str] = Field(None, description="Quality level for the generation")
+     negative_prompt: Optional[str] = Field(None, description="What to avoid in the generated image")

+ class ImageData(BaseModel):
+     url: Optional[str] = Field(None, description="The URL of the generated image")
+     b64_json: Optional[str] = Field(None, description="Base64 encoded JSON string of the image")
+     revised_prompt: Optional[str] = Field(None, description="The prompt after any revisions")
+
+ class ImageGenerationResponse(BaseModel):
+     created: int = Field(..., description="Unix timestamp for when the request was created")
+     data: List[ImageData] = Field(..., description="List of generated images")
+
+ class ModelsListResponse(BaseModel):
+     object: str = Field("list", description="Object type")
+     data: List[Dict[str, Any]] = Field(..., description="List of available models")
+
+ class ErrorResponse(BaseModel):
+     error: Dict[str, Any] = Field(..., description="Error details")
+
+ # Error handling
+ class APIError(Exception):
+     def __init__(self, message, code=400, param=None, type="invalid_request_error"):
+         self.message = message
+         self.code = code
+         self.param = param
+         self.type = type
+
+ # Authentication dependency
+ async def verify_api_key(authorization: Optional[str] = Header(None)):
+     if authorization is None:
+         raise HTTPException(
+             status_code=401,
+             detail={
+                 "error": {
+                     "message": "No API key provided",
+                     "type": "authentication_error",
+                     "param": None,
+                     "code": "no_api_key"
+                 }
+             }
+         )
+
+     # Extract the key from the Authorization header
+     parts = authorization.split()
+     if len(parts) != 2 or parts[0].lower() != "bearer":
+         raise HTTPException(
+             status_code=401,
+             detail={
+                 "error": {
+                     "message": "Invalid authentication format. Use 'Bearer YOUR_API_KEY'",
+                     "type": "authentication_error",
+                     "param": None,
+                     "code": "invalid_auth_format"
+                 }
+             }
+         )
+
+     api_key = parts[1]
+
+     # Check if the API key is valid
+     # In production, you'd want to use a more secure method
+     if api_key not in API_KEYS:
+         raise HTTPException(
+             status_code=401,
+             detail={
+                 "error": {
+                     "message": "Invalid API key",
+                     "type": "authentication_error",
+                     "param": None,
+                     "code": "invalid_api_key"
+                 }
+             }
+         )
+
+     return api_key

+ # Find provider from model ID - updating this function to support provider/model format
+ def get_provider_for_model(model: str):
+     model = model.lower()
+
+     # Check if it's in the format 'provider/model'
+     if "/" in model:
+         provider_name, model_name = model.split("/", 1)
+         model_name = model_name.replace("/", "-").replace(".", "-").replace("_", "-").lower()
+
+         # Check if provider exists
+         if provider_name not in PROVIDER_MAP:
+             raise APIError(
+                 message=f"Provider '{provider_name}' not found",
+                 code=404,
+                 type="provider_not_found"
              )
+
+         # Check if model exists for this provider
+         provider_models = PROVIDER_MODEL_INFO[provider_name]["models"]
+         if model_name not in provider_models:
+             # Try searching with less normalization - some providers might use underscore variants
+             original_model_name = model_name.replace("-", "_")
+             if original_model_name not in [m.replace("-", "_") for m in provider_models]:
+                 raise APIError(
+                     message=f"Model '{model_name}' not found for provider '{provider_name}'",
+                     code=404,
+                     type="model_not_found"
+                 )
+
+         return provider_name, model_name
+
+     # If not in provider/model format, search all providers (original behavior)
+     for provider_name, provider_info in PROVIDER_MODEL_INFO.items():
+         # Check if this model belongs to this provider
+         if model in provider_info["models"] or model == provider_info["default"]:
+             return provider_name, model
+
+     # If no provider found, return error
+     raise APIError(
+         message=f"Model '{model}' not found",
+         code=404,
+         type="model_not_found"
+     )
+
+ # Health check endpoint
+ @app.get("/health", response_model=Dict[str, str])
+ async def health_check():
+     return {"status": "ok"}

+ # OpenAI-compatible endpoints

+ # List available models
+ @app.get("/v1/models", response_model=ModelsListResponse, dependencies=[Depends(verify_api_key)])
+ async def list_models():
+     models_data = []
+
+     for provider_name, provider_info in PROVIDER_MODEL_INFO.items():
+         provider_description = PROVIDER_MAP.get(provider_name, {}).get("description", "")
+
+         for model_name in provider_info["models"]:
+             is_default = model_name == provider_info["default"]
+
+             models_data.append({
+                 "id": model_name,
+                 "object": "model",
+                 "created": int(time.time()),
+                 "owned_by": provider_name,
+                 "permission": [],
+                 "root": model_name,
+                 "parent": None,
+                 "description": f"{provider_description} - {'Default model' if is_default else 'Alternative model'}",
+             })
+
+     return {
+         "object": "list",
+         "data": models_data
+     }
+
+ # Get model information
+ @app.get("/v1/models/{model_id}", dependencies=[Depends(verify_api_key)])
+ async def get_model(model_id: str):
      try:
+         provider_name, model = get_provider_for_model(model_id)
+         provider_description = PROVIDER_MAP.get(provider_name, {}).get("description", "")
+
+         return {
+             "id": model,
+             "object": "model",
+             "created": int(time.time()),
+             "owned_by": provider_name,
+             "permission": [],
+             "root": model,
+             "parent": None,
+             "description": provider_description
+         }
+     except APIError as e:
+         return JSONResponse(
+             status_code=e.code,
+             content={"error": {"message": e.message, "type": e.type, "param": e.param, "code": e.code}}
+         )

+ # Generate images
+ @app.post("/v1/images/generations", response_model=ImageGenerationResponse, dependencies=[Depends(verify_api_key)])
+ async def create_image(request: ImageGenerationRequest, background_tasks: BackgroundTasks):
      try:
+         # Get provider for the requested model
+         provider_name, model = get_provider_for_model(request.model)
+         provider_class = PROVIDER_MAP[provider_name]["class"]
+
+         # Parse size
+         width, height = 1024, 1024
+         if request.size:
+             try:
+                 size_parts = request.size.split("x")
+                 if len(size_parts) == 2:
+                     width, height = int(size_parts[0]), int(size_parts[1])
+                 else:
+                     width = height = int(size_parts[0])
+             except:
+                 pass
+
+         # Create task ID
+         task_id = str(uuid.uuid4())
+         IMAGE_STORAGE[task_id] = {"status": "processing", "images": []}
+
+         # Get default params and update with user-provided values
+         default_params = PROVIDER_MODEL_INFO[provider_name].get("default_params", {}).copy()
+
+         # Add additional parameters from the request
+         if request.negative_prompt:
+             default_params["negative_prompt"] = request.negative_prompt
+         if request.quality:
+             default_params["quality"] = request.quality
+         if request.style:
+             default_params["style"] = request.style
+
+         # Update size parameters
+         default_params["width"] = width
+         default_params["height"] = height
+
+         # Function to generate images in the background
+         async def generate_images():
+             try:
+                 # Initialize provider based on the provider name
+                 if provider_name == "freeai":
+                     provider_instance = provider_class(model=model)
+                 elif provider_name == "deepinfra" and "-flux-" in model:
+                     # Convert back to model format expected by provider
+                     original_model = "black-forest-labs/FLUX-1-schnell"
+                     provider_instance = provider_class(model=original_model)
+                 else:
+                     provider_instance = provider_class()
+
+                 # Generate images with provider-specific parameters
+                 # Each provider may have different parameter requirements
+                 if provider_name == "aiforce":
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n,
+                         model=model.replace("-", "_"),  # Convert back to format used by provider
+                         width=default_params.get("width", 768),
+                         height=default_params.get("height", 768),
+                         seed=default_params.get("seed", None)
+                     )
+                 elif provider_name == "deepinfra":
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n,
+                         num_inference_steps=default_params.get("num_inference_steps", 25),
+                         guidance_scale=default_params.get("guidance_scale", 7.5),
+                         width=default_params.get("width", 1024),
+                         height=default_params.get("height", 1024),
+                         seed=default_params.get("seed", None)
+                     )
+                 elif provider_name == "nexra":
+                     # Convert back to original model format
+                     original_model = model.replace("-", "_")
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n,
+                         model=original_model,
+                         additional_params=default_params
+                     )
+                 elif provider_name == "freeai":
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n,
+                         size=f"{width}x{height}",
+                         quality=default_params.get("quality", "standard"),
+                         style=default_params.get("style", "vivid")
+                     )
+                 elif provider_name == "ninja":
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n,
+                         model=model.replace("-", "_")
+                     )
+                 elif provider_name == "artbit":
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n,
+                         caption_model=model,
+                         selected_ratio=default_params.get("selected_ratio", "1024"),
+                         negative_prompt=default_params.get("negative_prompt", "")
+                     )
+                 elif provider_name == "huggingface":
+                     # Convert from dash format to slash format for HF
+                     original_model = model.replace("-", "/")
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n,
+                         model=original_model,
+                         guidance_scale=default_params.get("guidance_scale", 7.5),
+                         negative_prompt=default_params.get("negative_prompt", None),
+                         num_inference_steps=default_params.get("num_inference_steps", 30),
+                         width=width,
+                         height=height
+                     )
+                 elif provider_name == "aiarta" and AIARTA_AVAILABLE:
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n,
+                         model=model,
+                         negative_prompt=default_params.get("negative_prompt", "blurry, deformed hands, ugly"),
+                         guidance_scale=default_params.get("guidance_scale", 7),
+                         num_inference_steps=default_params.get("num_inference_steps", 30),
+                         aspect_ratio=default_params.get("aspect_ratio", "1:1")
+                     )
+                 else:
+                     # Default case for providers with simpler interfaces
+                     images = await provider_instance.generate(
+                         prompt=request.prompt,
+                         amount=request.n
+                     )
+
+                 # Process and store the generated images
+                 for i, img in enumerate(images):
+                     # Handle both URL strings and binary data
+                     if isinstance(img, str):
+                         # For providers that return URLs instead of binary data
+                         async with aiohttp.ClientSession() as session:
+                             async with session.get(img) as resp:
+                                 resp.raise_for_status()
+                                 img_data = await resp.read()
+                     else:
+                         img_data = img
+
+                     # Generate a unique URL for the image
+                     image_id = f"{i}"
+                     image_url = f"/v1/images/{task_id}/{image_id}"
+
+                     # Store image data based on requested format
+                     if request.response_format == "b64_json":
+                         encoded = base64.b64encode(img_data).decode('utf-8')
+                         IMAGE_STORAGE[task_id]["images"].append({
+                             "image_id": image_id,
+                             "data": encoded,
+                             "url": image_url,
+                         })
+                     else:  # Default to URL
+                         IMAGE_STORAGE[task_id]["images"].append({
+                             "image_id": image_id,
+                             "data": img_data,
+                             "url": image_url,
+                         })
+
+                 # Update task status
+                 IMAGE_STORAGE[task_id]["status"] = "completed"
+             except Exception as e:
+                 # Handle errors
+                 IMAGE_STORAGE[task_id]["status"] = "failed"
+                 IMAGE_STORAGE[task_id]["error"] = str(e)
+
+         # Start background task
+         background_tasks.add_task(generate_images)
+
+         # Immediate response with task details
+         # For compatibility, we need to structure this like OpenAI's response
+         created_timestamp = int(time.time())
+
+         # Wait briefly to allow the background task to start
+         await asyncio.sleep(0.1)
+
+         # Check if the task failed immediately
+         if IMAGE_STORAGE[task_id]["status"] == "failed":
+             error_message = IMAGE_STORAGE[task_id].get("error", "Unknown error")
+             raise APIError(message=f"Image generation failed: {error_message}", code=500)
+
+         # Prepare response data
+         image_data = []
+         for i in range(request.n):
+             if request.response_format == "b64_json":
+                 image_data.append({
+                     "b64_json": "",  # Will be filled in by the background task
+                     "revised_prompt": request.prompt
+                 })
              else:
+                 image_data.append({
+                     "url": f"/v1/images/{task_id}/{i}",
+                     "revised_prompt": request.prompt
+                 })
+
+         return {
+             "created": created_timestamp,
+             "data": image_data
+         }

+     except APIError as e:
+         return JSONResponse(
+             status_code=e.code,
+             content={"error": {"message": e.message, "type": e.type, "param": e.param, "code": e.code}}
          )
      except Exception as e:
+         return JSONResponse(
+             status_code=500,
+             content={"error": {"message": str(e), "type": "server_error", "param": None, "code": 500}}
          )

+ # Image retrieval endpoint
+ @app.get("/v1/images/{task_id}/{image_id}", dependencies=[Depends(verify_api_key)])
+ async def get_image(task_id: str, image_id: str):
+     if task_id not in IMAGE_STORAGE:
+         return JSONResponse(
+             status_code=404,
+             content={"error": {"message": f"Image not found", "type": "not_found_error"}}
+         )
+
+     task_data = IMAGE_STORAGE[task_id]
+
+     if task_data["status"] == "failed":
+         return JSONResponse(
+             status_code=500,
+             content={"error": {"message": f"Image generation failed: {task_data.get('error', 'Unknown error')}", "type": "processing_error"}}
+         )
+
+     if task_data["status"] == "processing":
+         return JSONResponse(
+             status_code=202,
+             content={"status": "processing", "message": "Image is still being generated"}
+         )
+
+     # Find the requested image
+     for img in task_data["images"]:
+         if img["image_id"] == image_id:
+             # If it's stored as base64, it's already in the right format
+             if isinstance(img["data"], str):
+                 return JSONResponse(content={"b64_json": img["data"]})
+
+             # If it's binary data, return as an image stream
+             return StreamingResponse(
+                 io.BytesIO(img["data"]),
+                 media_type="image/png"
+             )
+
+     return JSONResponse(
+         status_code=404,
+         content={"error": {"message": f"Image not found", "type": "not_found_error"}}
+     )
+
+ # Legacy endpoints for backward compatibility
+ @app.get("/providers")
+ async def list_providers_legacy():
+     providers = {}
+     for provider_name, provider_info in PROVIDER_MAP.items():
+         model_info = PROVIDER_MODEL_INFO.get(provider_name, {})
+         providers[provider_name] = {
+             "description": provider_info.get("description", ""),
+             "default_model": model_info.get("default", "default"),
+             "models": model_info.get("models", ["default"]),
+             "default_params": model_info.get("default_params", {})
+         }
+     return providers
+
+ # Main entry point
  if __name__ == "__main__":
+     uvicorn.run(
+         "app:app",
+         host="0.0.0.0",
+         port=8000,
+         reload=True
+     )
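
A minimal client sketch for the OpenAI-compatible flow this commit introduces. It assumes the server is running locally on port 8000 with the demo key defined in API_KEYS; the model name, prompt, and output filename are illustrative, while the endpoint paths, auth header format, and 202-while-processing behavior follow the code above.

# Hypothetical usage sketch for the new /v1 endpoints.
import time
import requests

BASE = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer sk-demo-key"}  # demo key from API_KEYS

# POST /v1/images/generations returns immediately with relative image URLs;
# generation continues in a FastAPI background task.
resp = requests.post(
    f"{BASE}/v1/images/generations",
    headers=HEADERS,
    json={"model": "dall-e-3", "prompt": "a lighthouse at dusk", "n": 1, "size": "1024x1024"},
)
resp.raise_for_status()
image_url = resp.json()["data"][0]["url"]  # e.g. /v1/images/<task_id>/0

# The image endpoint answers 202 while the background task is still processing.
while True:
    img = requests.get(f"{BASE}{image_url}", headers=HEADERS)
    if img.status_code != 202:
        break
    time.sleep(1)
img.raise_for_status()

with open("output.png", "wb") as f:  # filename is illustrative
    f.write(img.content)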