Jamiiwej2903 committed
Commit cf39022 · verified · 1 Parent(s): 5b3b82b

Rename app.py to main.py

Files changed (1):
  1. app.py → main.py  +26 -30

app.py → main.py  RENAMED
@@ -1,7 +1,8 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
 from huggingface_hub import InferenceClient
-from fastapi.responses import StreamingResponse
+import uvicorn
+
 
 app = FastAPI()
 
@@ -24,35 +25,30 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-async def generate_stream(item: Item):
-    try:
-        temperature = max(float(item.temperature), 1e-2)
-        top_p = float(item.top_p)
-
-        generate_kwargs = dict(
-            temperature=temperature,
-            max_new_tokens=item.max_new_tokens,
-            top_p=top_p,
-            repetition_penalty=item.repetition_penalty,
-            do_sample=True,
-            seed=42,
-        )
-
-        formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-
-        for response in stream:
-            yield response.token.text
-    except Exception as e:
-        print(f"Error in generate_stream: {e}")
-    finally:
-        if 'stream' in locals():
-            stream.close()
+def generate(item: Item):
+    temperature = float(item.temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(item.top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=item.max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=item.repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+
+    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+    return output
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    try:
-        return StreamingResponse(generate_stream(item), media_type="text/plain")
-    except Exception as e:
-        print(f"Error in generate_text: {e}")
-        return {"error": str(e)}
+    return {"response": generate(item)}
+
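For callers, the effect of this commit is that /generate/ no longer streams plain text via StreamingResponse; the full completion now arrives as one JSON object. A minimal client sketch of the new behavior follows. The payload fields are inferred from the item.* attributes used in the diff (the Item model itself sits outside the changed hunks), and the host and port are placeholders, not part of this commit.

import requests

# Hypothetical request body: field names inferred from item.prompt,
# item.history, item.system_prompt, etc. in the diff above; the actual
# Pydantic Item model is not shown in the changed hunks.
payload = {
    "prompt": "Explain FastAPI in one sentence.",
    "history": [],
    "system_prompt": "You are a concise assistant.",
    "temperature": 0.7,
    "max_new_tokens": 256,
    "top_p": 0.95,
    "repetition_penalty": 1.1,
}

# Placeholder URL: before this commit the endpoint streamed tokens as
# plain text; after it, the whole generation is buffered server-side
# and returned as a single JSON object.
resp = requests.post("http://localhost:7860/generate/", json=payload)
resp.raise_for_status()
print(resp.json()["response"])

The new import uvicorn also suggests the app is now meant to be launched directly (e.g. via a uvicorn.run(app, ...) entry point), though that block is outside the hunks shown in this diff.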