Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -1,3 +1,245 @@
 import gradio as gr
+from llama_cpp import Llama
+import torch
+import os
+from accelerate import Accelerator
+import tensorflow as tf  # Import TensorFlow
+import numpy as np  # For handling input data
+
+# Set device for PyTorch
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("device set to:", device)
+
+# Initialize the accelerator
+accelerator = Accelerator()
+
+class LocalInferenceClient:
+    def __init__(self, model_name: str, model_path: str):
+        """
+        Initialize the inference client with the model.
+
+        Args:
+            model_name (str): The name of the model.
+            model_path (str): The path to the model file or directory.
+        """
+        self.model_name = model_name
+        self.model_path = model_path
+
+        # Initialize the Llama model specifically for gguf
+        self.model = Llama(model_path=model_path, n_ctx=2048, n_threads=8, n_gpu_layers=5)
+
+        # Move the model to the appropriate device
+        self.model = accelerator.prepare(self.model)
+
+        # Load the TensorFlow Lite model
+        self.tflite_interpreter = tf.lite.Interpreter(model_path='model.tflite')
+        self.tflite_interpreter.allocate_tensors()
+
+        # Get input and output tensors
+        self.input_details = self.tflite_interpreter.get_input_details()
+        self.output_details = self.tflite_interpreter.get_output_details()
+
+    def text_generation(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
+        """
+        Generate text based on the provided prompt.
+
+        Args:
+            prompt (str): The input prompt.
+            max_new_tokens (int): The maximum number of tokens to generate.
+            temperature (float): Sampling temperature.
+            top_p (float): Nucleus sampling probability.
+
+        Returns:
+            str: The generated text.
+        """
+        # Use the Llama model for text generation
+        response = self.model.create_chat_completion(
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p
+        )
+
+        # Print the response to understand its structure
+        print("Response from model:", response)
+
+        # Access the content correctly based on the response structure
+        if 'choices' in response and len(response['choices']) > 0:
+            return response['choices'][0]['message']['content']  # Access the content key
+        else:
+            return "⚠️ Error: Unexpected response format."
+
+    def run_tflite_model(self, input_data: np.ndarray) -> np.ndarray:
+        """
+        Run inference using the TensorFlow Lite model.
+
+        Args:
+            input_data (np.ndarray): Input data for the model.
+
+        Returns:
+            np.ndarray: Output data from the model.
+        """
+        # Set the input tensor
+        self.tflite_interpreter.set_tensor(self.input_details[0]['index'], input_data)
+
+        # Run the model
+        self.tflite_interpreter.invoke()
+
+        # Get the output tensor
+        output_data = self.tflite_interpreter.get_tensor(self.output_details[0]['index'])
+        return output_data
+
+# Specify the model paths for gguf models
+model_configs = {
+    "Lake 1 Chat": {
+        "path": r"C:\Users\BI Corp\Videos\main\Lake-1-chat\Lake-1-Chat.gguf",
+        "specs": """
+## Lake 1 Chat Specifications
+- **Architecture**: Lake 1
+- **Parameters**: 14B
+- **Capabilities**: Reasoning, logical inference, coding
+- **Intended Use**: Suitable for complex reasoning tasks, math, coding problems, and detailed conversations.
+"""
+    },
+    "Lake 1 Mini": {
+        "path": r"C:\Users\BI Corp\Videos\main\Lake-1-mini\Lake-1-Mini.gguf",
+        "specs": """
+## Lake 1 Mini Specifications
+- **Architecture**: Lake 1
+- **Parameters**: 6B
+- **Capabilities**: Quick responses, compact model
+- **Intended Use**: Great for fast responses and lightweight use cases.
+"""
+    },
+    "Lake 1 Base": {
+        "path": r"C:\Users\BI Corp\Videos\main\Lake-1-base\Lake-1-Base.gguf",
+        "specs": """
+## Lake 1 Base Specifications
+- **Architecture**: Lake 1
+- **Parameters**: 12B
+- **Capabilities**: Balanced performance between speed and accuracy
+- **Intended Use**: Best for use cases requiring a balance of speed and detail in responses.
+"""
+    },
+}
+
+# Set up a dictionary mapping model names to their clients
+clients = {name: LocalInferenceClient(name, config['path']) for name, config in model_configs.items()}
+
+# Presets for performance/quality tradeoffs
+presets = {
+    "Lake 1 Mini": {
+        "Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
+        "Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
+        "Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
+    },
+    "Lake 1 Base": {
+        "Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
+        "Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
+        "Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
+    },
+    "Lake 1 Chat": {
+        "Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
+        "Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
+        "Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
+    }
+}
+
+# A system prompt for each model
+system_messages = {
+    "Lake 1 Chat": "You are Lake 1 Chat, a powerful open-source reasoning model. Think carefully and answer step by step.",
+    "Lake 1 Mini": "You are Lake 1 Mini, a powerful open-source compact model. Think and answer fast.",
+    "Lake 1 Base": "You are Lake 1 Base, a powerful open-source original model. Think and answer step by step but balance speed and accuracy.",
+}
+
+def generate_response(message: str, model_name: str, preset: str) -> str:
+    """
+    Generate a response based on the user's message.
+
+    Args:
+        message (str): The user's message.
+        model_name (str): The name of the model to use.
+        preset (str): The performance preset to apply.
+
+    Returns:
+        str: The generated response.
+    """
+    client = clients[model_name]
+    params = presets[model_name][preset]
+    system_msg = system_messages[model_name]
+    prompt = f"{system_msg}\n\nUser: {message}\nAssistant:"
+    return client.text_generation(
+        prompt,
+        max_new_tokens=params["max_new_tokens"],
+        temperature=params["temperature"],
+        top_p=params["top_p"]
+    )
+
+def handle_chat(message: str, history: list, model: str, preset: str) -> str:
+    """
+    Handle the chat interaction.
+
+    Args:
+        message (str): The user's message.
+        history (list): The conversation history.
+        model (str): The model to use.
+        preset (str): The performance preset.
+
+    Returns:
+        str: The generated response.
+    """
+    try:
+        return generate_response(message, model, preset)
+    except Exception as e:
+        return f"⚠️ Error: {str(e)}"
+
+with gr.Blocks(title="BI CORP AI Assistant", theme="soft") as demo:
+    gr.Markdown("# <center>Lake AI Assistant</center>")
+    gr.Markdown("### <center>Powered by Lake 1 Chat</center>")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            model_dropdown = gr.Dropdown(
+                label="🤖 Model Selection",
+                choices=list(clients.keys()),
+                value="Lake 1 Chat",
+                interactive=True
+            )
+            preset_dropdown = gr.Dropdown(
+                label="⚙️ Performance Preset",
+                choices=["Fast", "Normal", "Quality"],
+                value="Normal",
+                interactive=True
+            )
+            model_info_md = gr.Markdown(
+                value=model_configs["Lake 1 Chat"]["specs"],
+                label="📝 Model Specifications"
+            )
+
+        with gr.Column(scale=3):
+            chat_interface = gr.ChatInterface(
+                fn=handle_chat,
+                additional_inputs=[model_dropdown, preset_dropdown],
+                examples=[["Explain quantum computing", "Lake 1 Chat", "Normal"]],
+                chatbot=gr.Chatbot(height=600, label="💬 Conversation", show_copy_button=True),
+                textbox=gr.Textbox(placeholder="Type your message...", container=False, scale=7, autofocus=True),
+                submit_btn=gr.Button("🚀 Send", variant="primary")
+            )
+
+    clear_button = gr.Button("🧹 Clear History")
+    clear_button.click(
+        fn=lambda: None,
+        inputs=[],
+        outputs=chat_interface.chatbot,
+        queue=False
+    )
+
+    model_dropdown.change(
+        fn=lambda model: model_configs[model]["specs"],
+        inputs=model_dropdown,
+        outputs=model_info_md,
+        queue=False
+    )
+
+if __name__ == "__main__":
+    demo.launch(server_port=7865)
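
For reference, a minimal sketch of the call that text_generation wraps, assuming a valid local .gguf file is available (the model_path below is a placeholder, not one of the paths committed above). create_chat_completion in llama-cpp-python returns an OpenAI-style dictionary, which is why the app reads response['choices'][0]['message']['content']:

from llama_cpp import Llama

# Placeholder path: substitute a .gguf file that actually exists in the runtime environment.
llm = Llama(model_path="model.gguf", n_ctx=2048)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Explain quantum computing"}],
    max_tokens=200,
    temperature=0.7,
    top_p=0.95,
)

# OpenAI-style structure: choices[0] -> message -> content holds the assistant reply.
print(response["choices"][0]["message"]["content"])

Note that the committed app builds every LocalInferenceClient eagerly at import time, so the hard-coded Windows .gguf paths and the local model.tflite file must all exist wherever the Space runs before the Gradio UI can even start.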