MatteoScript commited on
Commit
a81da59
·
1 Parent(s): 83ab480

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +18 -9
main.py CHANGED
@@ -7,20 +7,29 @@ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
7
 
8
  class InputData(BaseModel):
9
  input: str
 
 
 
 
10
 
11
  def format_prompt(message, history):
12
- prompt = "<s>"
13
- for user_prompt, bot_response in history:
14
- prompt += f"[INST] {user_prompt} [/INST]"
15
- prompt += f" {bot_response}</s> "
16
- prompt += f"[INST] {message} [/INST]"
17
- return prompt
18
-
19
- @app.post("/Genera/")
20
  def read_root(request: Request, input_data: InputData):
21
  input_text = input_data.input
 
 
 
 
 
22
  history = [] # Puoi definire la history se necessario
23
- generated_response = generate(input_text, history)
24
  return {"response": generated_response}
25
 
26
  @app.get("/")
 
7
 
8
  class InputData(BaseModel):
9
  input: str
10
+ temperature: float = 0.2
11
+ max_new_tokens: int = 30000
12
+ top_p: float = 0.95
13
+ repetition_penalty: float = 1.0
14
 
15
  def format_prompt(message, history):
16
+ prompt = "<s>"
17
+ for user_prompt, bot_response in history:
18
+ prompt += f"[INST] {user_prompt} [/INST]"
19
+ prompt += f" {bot_response}</s> "
20
+ prompt += f"[INST] {message} [/INST]"
21
+ return prompt
22
+
23
+ @app.post("/Genera")
24
  def read_root(request: Request, input_data: InputData):
25
  input_text = input_data.input
26
+ temperature = input_data.temperature
27
+ max_new_tokens = input_data.max_new_tokens
28
+ top_p = input_data.top_p
29
+ repetition_penalty = input_data.repetition_penalty
30
+
31
  history = [] # Puoi definire la history se necessario
32
+ generated_response = generate(input_text, history, temperature, max_new_tokens, top_p, repetition_penalty)
33
  return {"response": generated_response}
34
 
35
  @app.get("/")