OjciecTadeusz committed
Commit d6b0a9b · verified · 1 Parent(s): 1fb73a8

Update app.py

Files changed (1): app.py (+61 -55)
app.py CHANGED
@@ -1,34 +1,28 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
- import json
  from fastapi import FastAPI, Request
  from fastapi.responses import JSONResponse
  import datetime
+ import requests
+ import os
+ import json
  import asyncio

  # Initialize FastAPI
  app = FastAPI()

- # Load model and tokenizer
- model_name = "Qwen/Qwen2.5-Coder-32B"
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
- # Configure model loading with specific parameters
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     device_map="auto",
-     trust_remote_code=True,
-     torch_dtype=torch.float16,
-     low_cpu_mem_usage=True
- )
+ # Configuration
+ API_URL = "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-32B"
+ headers = {
+     "Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}",
+     "Content-Type": "application/json"
+ }

- def format_chat_response(response_text, prompt_tokens, completion_tokens):
+ def format_chat_response(response_text, prompt_tokens=0, completion_tokens=0):
      return {
          "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
          "object": "chat.completion",
          "created": int(datetime.datetime.now().timestamp()),
-         "model": model_name,
+         "model": "Qwen/Qwen2.5-Coder-32B",
          "choices": [{
              "index": 0,
              "message": {
@@ -44,37 +38,48 @@ def format_chat_response(response_text, prompt_tokens, completion_tokens):
          }
      }

+ async def query_model(payload):
+     response = requests.post(API_URL, headers=headers, json=payload)
+     return response.json()
+
  @app.post("/v1/chat/completions")
  async def chat_completion(request: Request):
      try:
          data = await request.json()
          messages = data.get("messages", [])

-         # Convert messages to model input format
-         prompt = tokenizer.apply_chat_template(
-             messages,
-             tokenize=False,
-             add_generation_prompt=True
-         )
-
-         # Count prompt tokens
-         prompt_tokens = len(tokenizer.encode(prompt))
-
-         # Generate response
-         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=data.get("max_tokens", 2048),
-             temperature=data.get("temperature", 0.7),
-             top_p=data.get("top_p", 0.95),
-             do_sample=True
-         )
-
-         response_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
-         completion_tokens = len(tokenizer.encode(response_text))
+         # Prepare the payload for the Inference API
+         payload = {
+             "inputs": {
+                 "messages": messages
+             },
+             "parameters": {
+                 "max_new_tokens": data.get("max_tokens", 2048),
+                 "temperature": data.get("temperature", 0.7),
+                 "top_p": data.get("top_p", 0.95),
+                 "do_sample": True
+             }
+         }
+
+         # Get response from model
+         response = await query_model(payload)
+
+         if isinstance(response, dict) and "error" in response:
+             return JSONResponse(
+                 status_code=500,
+                 content={"error": response["error"]}
+             )
+
+         response_text = response[0]["generated_text"]

          return JSONResponse(
-             content=format_chat_response(response_text, prompt_tokens, completion_tokens)
+             content=format_chat_response(
+                 response_text,
+                 # Note: Actual token counts would need to be calculated differently
+                 # or obtained from the API response if available
+                 prompt_tokens=0,
+                 completion_tokens=0
+             )
          )
      except Exception as e:
          return JSONResponse(
@@ -82,26 +87,27 @@ async def chat_completion(request: Request):
              content={"error": str(e)}
          )

- # Synchronous function to generate response
+ # Synchronous function to generate response for Gradio
  def generate_response(messages):
-     # Convert messages to model input format
-     prompt = tokenizer.apply_chat_template(
-         messages,
-         tokenize=False,
-         add_generation_prompt=True
-     )
-
-     # Generate response
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=2048,
-         temperature=0.7,
-         top_p=0.95,
-         do_sample=True
-     )
-
-     return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+     payload = {
+         "inputs": {
+             "messages": messages
+         },
+         "parameters": {
+             "max_new_tokens": 2048,
+             "temperature": 0.7,
+             "top_p": 0.95,
+             "do_sample": True
+         }
+     }
+
+     response = requests.post(API_URL, headers=headers, json=payload)
+     result = response.json()
+
+     if isinstance(result, dict) and "error" in result:
+         return f"Error: {result['error']}"
+
+     return result[0]["generated_text"]

  # Gradio interface for testing
  def chat_interface(message, history):
@@ -126,7 +132,7 @@ def chat_interface(message, history):
  interface = gr.ChatInterface(
      chat_interface,
      title="Qwen2.5-Coder-32B Chat",
-     description="Chat with Qwen2.5-Coder-32B model. This Space also provides a /v1/chat/completions endpoint."
+     description="Chat with Qwen2.5-Coder-32B model via Hugging Face Inference API. This Space also provides a /v1/chat/completions endpoint."
  )

  # Mount both FastAPI and Gradio
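
A note on the new query_model helper: it is declared async but calls the blocking requests.post directly, so the event loop stalls for the duration of each Inference API round trip. A minimal non-blocking sketch using the standard library's asyncio.to_thread (Python 3.9+), reusing the API_URL and headers defined in the diff above; this is a suggestion, not part of the commit:

import asyncio
import requests

async def query_model(payload):
    # Run the blocking HTTP call in a worker thread so the event loop
    # stays free to serve other requests while waiting on the API.
    response = await asyncio.to_thread(
        requests.post, API_URL, headers=headers, json=payload
    )
    return response.json()

An async HTTP client such as httpx.AsyncClient would be another common way to achieve the same thing.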
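The payload shape may also deserve a second look: the serverless Inference API's text-generation task typically expects "inputs" to be a plain string, so the nested {"messages": ...} object sent here may be rejected. If it is, one workaround is to flatten the chat history into a prompt first. The template below is a deliberately naive illustration, not Qwen's actual chat template:

def messages_to_prompt(messages):
    # Naive flattening, for illustration only; a real deployment would
    # apply the model's own chat template instead.
    lines = [f"{m['role']}: {m['content']}" for m in messages]
    lines.append("assistant:")
    return "\n".join(lines)

payload = {
    "inputs": messages_to_prompt(messages),
    "parameters": {"max_new_tokens": 2048, "temperature": 0.7, "top_p": 0.95},
}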
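For testing the endpoint this commit exposes, a minimal client sketch. SPACE_URL is a placeholder rather than a value from the commit, and reading choices[0]["message"]["content"] assumes the elided part of format_chat_response fills the message object with the usual OpenAI-style role and content fields:

import requests

SPACE_URL = "https://your-space.hf.space"  # placeholder, substitute the real Space URL

resp = requests.post(
    f"{SPACE_URL}/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Write hello world in Python."}],
        "max_tokens": 256,
        "temperature": 0.7,
    },
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])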