from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware  # Import the CORS middleware
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from datetime import datetime


app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# CORS is fully open here (any origin, method, and header); tighten before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class InputData(BaseModel):
    input: str
    temperature: float = 0.2
    max_new_tokens: int = 30000  # very high default; the inference backend may cap this lower
    top_p: float = 0.95
    repetition_penalty: float = 1.0
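
# Example request body for InputData (illustrative values, not from the original):
#   {"input": "Ciao!", "temperature": 0.5, "max_new_tokens": 256}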

def format_prompt(message, history):
    """Build a Mixtral-style instruction prompt from the chat history."""
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    # Prefix the new turn with a timestamp so the model can see when it was asked
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    prompt += f"[{now}] [INST] {message} [/INST]"

    return prompt
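
# Illustrative example (hypothetical values): with history=[("Hi", "Hello!")] and
# message="What's new?", format_prompt returns
# "<s>[INST] Hi [/INST] Hello!</s> [2024-01-01 00:00:00.000000] [INST] What's new? [/INST]"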

@app.post("/Genera")
def genera(request: Request, input_data: InputData):
    input_text = input_data.input
    temperature = input_data.temperature
    max_new_tokens = input_data.max_new_tokens
    top_p = input_data.top_p
    repetition_penalty = input_data.repetition_penalty

    history = []  # Conversation history could be threaded through here if needed
    generated_response = generate(input_text, history, temperature, max_new_tokens, top_p, repetition_penalty)
    return {"response": generated_response}

@app.get("/")
def read_general():
    return {"response": "Welcome. For more info, go to /docs"}  # Return the welcome message as JSON

def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # clamp: a temperature of 0 is invalid when sampling
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed makes generations reproducible for identical inputs
    )
    formatted_prompt = format_prompt(prompt, history)
    output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
    return output

    # Streaming alternative (unreachable after the return above; kept for reference):
    #stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=False, return_full_text=False)
    # Accumulate the output in a list (with details=False the stream yields plain strings)
    #output_list = []
    #for response in stream:
    #    output_list.append(response)
    #return iter(output_list)  # Return the list as an iterator
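
# Local entry point (a sketch, not in the original file; host and port are assumptions):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)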