OjciecTadeusz committed
Commit af7dbec · verified · 1 Parent(s): 3b698bf

Update app.py

Files changed (1)
  1. app.py +52 -151
app.py CHANGED
@@ -1,155 +1,56 @@
-
-import gradio as gr
-from fastapi import FastAPI, Request, HTTPException
-from fastapi.responses import JSONResponse
-import datetime
-import requests
-import os
-import logging
-import toml
+from fastapi import FastAPI
+from pydantic import BaseModel
+from huggingface_hub import InferenceClient
 import uvicorn
 
-# Initialize FastAPI
-app = FastAPI()
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Load config
-# with open("config.toml") as f:
-#     config = toml.load(f)
-
-#API_URL = os.getenv('API_URL')
-#API_TOKEN = os.getenv('API_TOKEN')
-# API_URL = 'https://ojciectadeusz-fastapi-inference-qwen2-5-coder-32b-instruct.hf.space/generate'
-API_URL = 'https://ojciectadeusz-fastapi-inference-qwen2-5-coder-32-a0ab504.hf.space/generate'
-headers = {
-    "Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}",
-    "Content-Type": "application/json"
-}
-
-def format_chat_response(response_text, prompt_tokens=0, completion_tokens=0):
-    return {
-        "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
-        "object": "chat.completion",
-        "created": int(datetime.datetime.now().timestamp()),
-        "model": "Qwen/Qwen2.5-Coder-32B",
-        "choices": [{
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": response_text
-            },
-            "finish_reason": "stop"
-        }],
-        "usage": {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": prompt_tokens + completion_tokens
-        }
-    }
-
-async def query_model(payload):
-    try:
-        response = requests.post(API_URL, headers=headers, json=payload)
-        response.raise_for_status()
-        return response.json()
-    except requests.exceptions.RequestException as e:
-        logger.error(f"Request failed: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
 
-@app.get("/status")
-async def status():
-    try:
-
-        response_text = "it's working"
-        return JSONResponse(content=format_chat_response(response_text))
-    except Exception as e:
-        logger.error(f"Status check failed: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
-
-@app.post("/generate")
-async def chat_completion(request: Request):
-    try:
-        data = await request.json()
-        messages = data.get("messages", [])
-        if not messages:
-            raise HTTPException(status_code=400, detail="Messages are required")
-
-        payload = {
-            "inputs": {
-                "messages": messages
-            },
-            "parameters": {
-                "max_new_tokens": data.get("max_tokens", 2048),
-                "temperature": data.get("temperature", 0.7),
-                "top_p": data.get("top_p", 0.95),
-                "do_sample": True
-            }
-        }
-
-        response = await query_model(payload)
-        print(response)
-        if isinstance(response, dict) and "error" in response:
-            raise HTTPException(status_code=500, detail=response["error"])
-
-        response_text = response[0]["generated_text"]
-
-        return JSONResponse(content=format_chat_response(response_text))
-    except HTTPException as e:
-        logger.error(f"Chat completion failed: {e.detail}")
-        raise e
-    except Exception as e:
-        logger.error(f"Unexpected error: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
-
-def generate_response(messages):
-    payload = {
-        "inputs": {
-            "messages": messages
-        },
-        "parameters": {
-            "max_new_tokens": 2048,
-            "temperature": 0.7,
-            "top_p": 0.95,
-            "do_sample": True
-        }
-    }
-
-    try:
-        response = requests.post(API_URL, headers=headers, json=payload)
-        response.raise_for_status()
-        result = response.json()
-
-        if isinstance(result, dict) and "error" in result:
-            return f"Error: {result['error']}"
-
-        return result[0]["generated_text"]
-    except requests.exceptions.RequestException as e:
-        logger.error(f"Request failed: {e}")
-        return f"Error: {e}"
-
-def chat_interface(messages):
-    chat_history = []
-    for message in messages:
-        try:
-            response = generate_response([{"role": "user", "content": message}])
-            chat_history.append({"role": "user", "content": message})
-            chat_history.append({"role": "assistant", "content": response})
-        except Exception as e:
-            chat_history.append({"role": "user", "content": message})
-            chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
-    return chat_history
-
-# Create Gradio interface
-def gradio_app():
-    return gr.ChatInterface(chat_interface, type="messages")
-
-# Mount both FastAPI and Gradio
-app = gr.mount_gradio_app(app, gradio_app(), path="/")
+app = FastAPI()
 
-# For running with uvicorn directly
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+# client = InferenceClient("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
+client = InferenceClient("Qwen/Qwen2.5-Coder-32B-Instruct")
+
+class Item(BaseModel):
+    prompt: str
+    history: list
+    system_prompt: str
+    temperature: float = 0.0
+    max_new_tokens: int = 1048
+    top_p: float = 0.15
+    repetition_penalty: float = 1.0
+
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+def generate(item: Item):
+    temperature = float(item.temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(item.top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=item.max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=item.repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+
+    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+    return output
+
+@app.post("/generate/")
+async def generate_text(item: Item):
+    data = {"response": generate(item)}
+    print(data)
+    return data
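
For reference, a minimal sketch of calling the new endpoint once the app is running (e.g. via "uvicorn app:app --host 0.0.0.0 --port 7860", since the rewritten file no longer has a __main__ block). The base URL and all field values below are illustrative assumptions, not part of this commit; the request body mirrors the Item model above.

# Hypothetical client for POST /generate/ -- base URL and values are assumptions.
import requests

BASE_URL = "http://localhost:7860"  # placeholder; substitute the deployed Space URL

payload = {
    "prompt": "Write a function that reverses a string.",
    "history": [],  # list of (user_prompt, bot_response) pairs, as format_prompt expects
    "system_prompt": "You are a helpful coding assistant.",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
}

resp = requests.post(f"{BASE_URL}/generate/", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])  # the endpoint returns {"response": "<generated text>"}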