BICORP committed on
Commit 464f46b · verified · 1 Parent(s): 9c2adc8

Update app.py

Files changed (1)
  1. app.py +243 -1
app.py CHANGED
@@ -1,3 +1,245 @@
 import gradio as gr
-
-gr.load("models/EleutherAI/gpt-neox-20b").launch()
+from llama_cpp import Llama
+import torch
+import os
+from accelerate import Accelerator
+import tensorflow as tf  # Import TensorFlow
+import numpy as np  # For handling input data
+
+# Set device for PyTorch
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("device set to:", device)
+
+# Initialize the accelerator
+accelerator = Accelerator()
+
+class LocalInferenceClient:
+    def __init__(self, model_name: str, model_path: str):
+        """
+        Initialize the inference client with the model.
+
+        Args:
+            model_name (str): The name of the model.
+            model_path (str): The path to the model file or directory.
+        """
+        self.model_name = model_name
+        self.model_path = model_path
+
+        # Initialize the Llama model specifically for gguf
+        self.model = Llama(model_path=model_path, n_ctx=2048, n_threads=8, n_gpu_layers=5)
+
+        # GPU offload is controlled by n_gpu_layers above; accelerator.prepare()
+        # returns non-PyTorch objects such as this Llama instance unchanged.
+        self.model = accelerator.prepare(self.model)
+
+        # Load the TensorFlow Lite model
+        self.tflite_interpreter = tf.lite.Interpreter(model_path='model.tflite')
+        self.tflite_interpreter.allocate_tensors()
+
+        # Get input and output tensors
+        self.input_details = self.tflite_interpreter.get_input_details()
+        self.output_details = self.tflite_interpreter.get_output_details()
+
+    def text_generation(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
+        """
+        Generate text based on the provided prompt.
+
+        Args:
+            prompt (str): The input prompt.
+            max_new_tokens (int): The maximum number of tokens to generate.
+            temperature (float): Sampling temperature.
+            top_p (float): Nucleus sampling probability.
+
+        Returns:
+            str: The generated text.
+        """
+        # Use the Llama model for text generation
+        response = self.model.create_chat_completion(
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p
+        )
+
+        # Print the response to understand its structure
+        print("Response from model:", response)
+
+        # Access the content correctly based on the response structure
+        if 'choices' in response and len(response['choices']) > 0:
+            return response['choices'][0]['message']['content']  # Access the content key
+        else:
+            return "⚠️ Error: Unexpected response format."
+
+    def run_tflite_model(self, input_data: np.ndarray) -> np.ndarray:
+        """
+        Run inference using the TensorFlow Lite model.
+
+        Args:
+            input_data (np.ndarray): Input data for the model.
+
+        Returns:
+            np.ndarray: Output data from the model.
+        """
+        # Set the input tensor
+        self.tflite_interpreter.set_tensor(self.input_details[0]['index'], input_data)
+
+        # Run the model
+        self.tflite_interpreter.invoke()
+
+        # Get the output tensor
+        output_data = self.tflite_interpreter.get_tensor(self.output_details[0]['index'])
+        return output_data
+
+# Specify the model paths for gguf models
+model_configs = {
+    "Lake 1 Chat": {
+        "path": r"C:\Users\BI Corp\Videos\main\Lake-1-chat\Lake-1-Chat.gguf",
+        "specs": """
+## Lake 1 Chat Specifications
+- **Architecture**: Lake 1
+- **Parameters**: 14B
+- **Capabilities**: Reasoning, logical inference, coding
+- **Intended Use**: Suitable for complex reasoning tasks, math, coding problems, and detailed conversations.
+"""
+    },
+    "Lake 1 Mini": {
+        "path": r"C:\Users\BI Corp\Videos\main\Lake-1-mini\Lake-1-Mini.gguf",
+        "specs": """
+## Lake 1 Mini Specifications
+- **Architecture**: Lake 1
+- **Parameters**: 6B
+- **Capabilities**: Quick responses, compact model
+- **Intended Use**: Great for fast responses and lightweight use cases.
+"""
+    },
+    "Lake 1 Base": {
+        "path": r"C:\Users\BI Corp\Videos\main\Lake-1-base\Lake-1-Base.gguf",
+        "specs": """
+## Lake 1 Base Specifications
+- **Architecture**: Lake 1
+- **Parameters**: 12B
+- **Capabilities**: Balanced performance between speed and accuracy
+- **Intended Use**: Best for use cases requiring a balance of speed and detail in responses.
+"""
+    },
+}
+
+# Set up a dictionary mapping model names to their clients
+clients = {name: LocalInferenceClient(name, config['path']) for name, config in model_configs.items()}
+
+# Presets for performance/quality tradeoffs
+presets = {
+    "Lake 1 Mini": {
+        "Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
+        "Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
+        "Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
+    },
+    "Lake 1 Base": {
+        "Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
+        "Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
+        "Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
+    },
+    "Lake 1 Chat": {
+        "Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
+        "Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
+        "Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
+    }
+}
+
+# Per-model system prompts
+system_messages = {
+    "Lake 1 Chat": "You are Lake 1 Chat, a powerful open-source reasoning model. Think carefully and answer step by step.",
+    "Lake 1 Mini": "You are Lake 1 Mini, a powerful open-source compact model. Think and answer fast.",
+    "Lake 1 Base": "You are Lake 1 Base, a powerful open-source original model. Think and answer step by step but balance speed and accuracy.",
+}
+
+def generate_response(message: str, model_name: str, preset: str) -> str:
+    """
+    Generate a response based on the user's message.
+
+    Args:
+        message (str): The user's message.
+        model_name (str): The name of the model to use.
+        preset (str): The performance preset to apply.
+
+    Returns:
+        str: The generated response.
+    """
+    client = clients[model_name]
+    params = presets[model_name][preset]
+    system_msg = system_messages[model_name]
+    prompt = f"{system_msg}\n\nUser: {message}\nAssistant:"
+    return client.text_generation(
+        prompt,
+        max_new_tokens=params["max_new_tokens"],
+        temperature=params["temperature"],
+        top_p=params["top_p"]
+    )
+
+def handle_chat(message: str, history: list, model: str, preset: str) -> str:
+    """
+    Handle the chat interaction.
+
+    Args:
+        message (str): The user's message.
+        history (list): The conversation history.
+        model (str): The model to use.
+        preset (str): The performance preset.
+
+    Returns:
+        str: The generated response.
+    """
+    try:
+        return generate_response(message, model, preset)
+    except Exception as e:
+        return f"⚠️ Error: {str(e)}"
+
+with gr.Blocks(title="BI CORP AI Assistant", theme="soft") as demo:
+    gr.Markdown("# <center>Lake AI Assistant</center>")
+    gr.Markdown("### <center>Powered by Lake 1 Chat</center>")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            model_dropdown = gr.Dropdown(
+                label="🤖 Model Selection",
+                choices=list(clients.keys()),
+                value="Lake 1 Chat",
+                interactive=True
+            )
+            preset_dropdown = gr.Dropdown(
+                label="⚙️ Performance Preset",
+                choices=["Fast", "Normal", "Quality"],
+                value="Normal",
+                interactive=True
+            )
+            model_info_md = gr.Markdown(
+                value=model_configs["Lake 1 Chat"]["specs"],
+                label="📝 Model Specifications"
+            )
+
+        with gr.Column(scale=3):
+            chat_interface = gr.ChatInterface(
+                fn=handle_chat,
+                additional_inputs=[model_dropdown, preset_dropdown],
+                examples=[["Explain quantum computing", "Lake 1 Chat", "Normal"]],
+                chatbot=gr.Chatbot(height=600, label="💬 Conversation", show_copy_button=True),
+                textbox=gr.Textbox(placeholder="Type your message...", container=False, scale=7, autofocus=True),
+                submit_btn=gr.Button("🚀 Send", variant="primary")
+            )
+
+            clear_button = gr.Button("🧹 Clear History")
+            clear_button.click(
+                fn=lambda: None,
+                inputs=[],
+                outputs=chat_interface.chatbot,
+                queue=False
+            )
+
+    model_dropdown.change(
+        fn=lambda model: model_configs[model]["specs"],
+        inputs=model_dropdown,
+        outputs=model_info_md,
+        queue=False
+    )
+
+if __name__ == "__main__":
+    demo.launch(server_port=7865)
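
The new run_tflite_model helper is loaded for every client but is not called from the Gradio UI. A minimal, hypothetical sketch of exercising it directly, assuming model.tflite takes a single float32 input tensor (the input shape is read from the interpreter rather than guessed; the variable names below are not part of the commit):

    import numpy as np

    # Hypothetical usage sketch: pick any initialized client from the `clients` dict
    client = clients["Lake 1 Mini"]
    input_shape = client.input_details[0]["shape"]          # shape reported by the TFLite interpreter
    dummy_input = np.zeros(input_shape, dtype=np.float32)   # assumes the model expects float32 input
    output = client.run_tflite_model(dummy_input)
    print("TFLite output shape:", output.shape)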