MatteoScript committed
Commit c550535 · 1 Parent(s): 33e9df8

Update main.py

Files changed (1):
  1. main.py +49 -46

main.py CHANGED
@@ -9,6 +9,7 @@ import requests
 import os
 import socket
 
+#--------------------------------------------------- FastAPI server definition ------------------------------------------------------
 app = FastAPI()
 client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
@@ -34,20 +35,7 @@ class InputImage(BaseModel):
     cfg: int = 5
     seed: int = 453666937
 
-def format_prompt(message, history):
-    prompt = "<s>"
-    #with open('Manuale.txt', 'r') as file:
-    #    manual_content = file.read()
-    #    prompt += f"Read this manual; afterwards I will ask you some questions: {manual_content}"
-
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
-    prompt += f"[{now}] [INST] {message} [/INST]"
-
-    return prompt
-
+#--------------------------------------------------- TEXT generation ------------------------------------------------------
 @app.post("/Genera")
 def read_root(request: Request, input_data: InputData):
     input_text = input_data.input
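As an illustration (not part of this commit), the string format_prompt assembles for a one-turn history looks roughly like this, following the Mixtral instruct convention; the sample history and timestamp below are invented:

# Hypothetical one-turn history, to show the prompt shape
history = [("What is FastAPI?", "A Python web framework.")]
# format_prompt("And what is Mixtral?", history) returns approximately:
# <s>[INST] What is FastAPI? [/INST] A Python web framework.</s> [2024-01-17 10:00:00.000000] [INST] And what is Mixtral? [/INST]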
@@ -55,40 +43,15 @@ def read_root(request: Request, input_data: InputData):
     max_new_tokens = input_data.max_new_tokens
     top_p = input_data.top_p
     repetition_penalty = input_data.repetition_penalty
-
-    history = []  # You can define the history here if necessary
+    history = []
     generated_response = generate(input_text, history, temperature, max_new_tokens, top_p, repetition_penalty)
     return {"response": generated_response}
 
-@app.post("/Immagine")
-def generate_image(request: Request, input_data: InputImage):
-    client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545b5tw7n/")
-    result = client.predict(
-        input_data.input,
-        input_data.negativePrompt,
-        input_data.steps,
-        input_data.cfg,
-        1024,
-        1024,
-        input_data.seed,
-        fn_index=0
-    )
-    image_url = result
-    with open(image_url, 'rb') as img_file:
-        img_binary = img_file.read()
-        img_base64 = base64.b64encode(img_binary).decode('utf-8')
-    return {"response": img_base64}
-
-@app.get("/")
-def read_general():
-    return {"response": "Benvenuto. Per maggiori info vai a /docs"}  # Return the generated response as JSON
-
 def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
     top_p = float(top_p)
-
     generate_kwargs = dict(
         temperature=temperature,
         max_new_tokens=max_new_tokens,
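For orientation, a minimal sketch of calling the reworked /Genera endpoint; the base URL is a placeholder assumption, and the payload fields simply mirror the attributes read_root pulls off input_data (the InputData model itself sits outside these hunks):

import requests

BASE_URL = "http://localhost:8000"  # hypothetical; adjust to the actual deployment

payload = {
    "input": "Explain FastAPI in one sentence.",
    "temperature": 0.2,
    "max_new_tokens": 512,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
}
resp = requests.post(f"{BASE_URL}/Genera", json=payload)
print(resp.json()["response"])

Since read_root always resets history to an empty list, each call is stateless; multi-turn context would have to be supplied by some other mechanism.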
@@ -100,10 +63,50 @@ def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95,
     formatted_prompt = format_prompt(prompt, history)
     output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
     return output
+
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+    prompt += f"[{now}] [INST] {message} [/INST]"
+    return prompt
+
+#--------------------------------------------------- IMAGE generation ------------------------------------------------------
+@app.post("/Immagine")
+def generate_image(request: Request, input_data: InputImage):
+    client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545b5tw7n/")
+    max_attempts = 10
+    attempt = 0
+    while attempt < max_attempts:
+        try:
+            result = client.predict(
+                input_data.input,
+                input_data.negativePrompt,
+                input_data.steps,
+                input_data.cfg,
+                1024,
+                1024,
+                input_data.seed,
+                fn_index=0
+            )
+            image_url = result
+            with open(image_url, 'rb') as img_file:
+                img_binary = img_file.read()
+                img_base64 = base64.b64encode(img_binary).decode('utf-8')
+            return {"response": img_base64}
+        except requests.exceptions.HTTPError as e:
+            if e.response.status_code == 500:
+                attempt += 1
+                if attempt < max_attempts:
+                    continue
+                else:
+                    return {"error": "Errore interno del server persistente"}
+            else:
+                return {"error": "Errore diverso da 500"}
+    return {"error": "Numero massimo di tentativi raggiunto"}
 
-#stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=False, return_full_text=False)
-# Accumulate the output in a list
-#output_list = []
-#for response in stream:
-#    output_list.append(response.token.text)
-#return iter(output_list)  # Return the list as an iterator
+@app.get("/")
+def read_general():
+    return {"response": "Benvenuto. Per maggiori info: https://matteoscript-fastapi.hf.space/docs"}