import gradio as gr
from llama_cpp import Llama
import torch
import os
import tensorflow as tf  # TensorFlow Lite interpreter for the auxiliary model
import numpy as np  # For handling TFLite input/output arrays

# Report the device PyTorch sees. llama-cpp-python manages GPU offloading
# itself via n_gpu_layers, so no Accelerate/device placement is needed here;
# this print is informational only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device set to:", device)
class LocalInferenceClient:
    def __init__(self, model_name: str, model_path: str):
        """
        Initialize the inference client with the model.

        Args:
            model_name (str): The name of the model.
            model_path (str): The path to the GGUF model file.
        """
        self.model_name = model_name
        self.model_path = model_path
        # Initialize the Llama model for GGUF inference; n_gpu_layers=5
        # offloads the first five layers to the GPU when one is available.
        self.model = Llama(model_path=model_path, n_ctx=2048, n_threads=8, n_gpu_layers=5)
        # Load the TensorFlow Lite model (assumed to live at ./model.tflite
        # alongside this script)
        self.tflite_interpreter = tf.lite.Interpreter(model_path="model.tflite")
        self.tflite_interpreter.allocate_tensors()
        # Cache input and output tensor metadata for inference calls
        self.input_details = self.tflite_interpreter.get_input_details()
        self.output_details = self.tflite_interpreter.get_output_details()
    def text_generation(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
        """
        Generate text based on the provided prompt.

        Args:
            prompt (str): The input prompt.
            max_new_tokens (int): The maximum number of tokens to generate.
            temperature (float): Sampling temperature.
            top_p (float): Nucleus sampling probability.

        Returns:
            str: The generated text.
        """
        # Use the Llama model's chat completion API for text generation
        response = self.model.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        # Log the raw response so its structure is easy to inspect
        print("Response from model:", response)
        # create_chat_completion returns an OpenAI-style dict; the text lives
        # at choices[0].message.content.
        if response.get("choices"):
            return response["choices"][0]["message"]["content"]
        return "⚠️ Error: Unexpected response format."
    def run_tflite_model(self, input_data: np.ndarray) -> np.ndarray:
        """
        Run inference using the TensorFlow Lite model.

        Args:
            input_data (np.ndarray): Input data for the model.

        Returns:
            np.ndarray: Output data from the model.
        """
        # Set the input tensor
        self.tflite_interpreter.set_tensor(self.input_details[0]['index'], input_data)
        # Run the model
        self.tflite_interpreter.invoke()
        # Read back the output tensor
        output_data = self.tflite_interpreter.get_tensor(self.output_details[0]['index'])
        return output_data
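
# Hypothetical usage sketch for run_tflite_model (the app itself never calls
# it). The zero-filled input below is an assumption; check
# input_details[0]["shape"] and ["dtype"] for the real contract of model.tflite:
#
#   client = clients["Test"]
#   dummy = np.zeros(client.input_details[0]["shape"], dtype=np.float32)
#   print(client.run_tflite_model(dummy))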
# Specify the model paths for GGUF models
model_configs = {
    "Test": {
        "path": r"./test-model.gguf",
        "specs": """
## Lake 1 Chat Specifications
- **Architecture**: Test
- **Parameters**: IDK
- **Capabilities**: test
- **Intended Use**: test
""",
    }
}
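
# Optional guard (an assumption about deployment, not required by Gradio):
# fail fast with a clear message if a configured GGUF file is missing, since
# llama-cpp's own error for a bad path is less obvious.
for _name, _config in model_configs.items():
    if not os.path.exists(_config["path"]):
        raise FileNotFoundError(f"Model file for '{_name}' not found: {_config['path']}")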
# Map model names to their inference clients
clients = {name: LocalInferenceClient(name, config["path"]) for name, config in model_configs.items()}
# Presets for performance/quality tradeoffs
presets = {
    "Test": {
        "Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
        "Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
        "Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
    }
}
# A system prompt for each model
system_messages = {
    "Test": "You are Lake 1 Chat, a powerful open-source reasoning model. Think carefully and answer step by step.",
}
def generate_response(message: str, model_name: str, preset: str) -> str:
    """
    Generate a response based on the user's message.

    Args:
        message (str): The user's message.
        model_name (str): The name of the model to use.
        preset (str): The performance preset to apply.

    Returns:
        str: The generated response.
    """
    client = clients[model_name]
    params = presets[model_name][preset]
    system_msg = system_messages[model_name]
    # Flatten the system message and user turn into a single prompt string;
    # text_generation() sends it as one user message.
    prompt = f"{system_msg}\n\nUser: {message}\nAssistant:"
    return client.text_generation(
        prompt,
        max_new_tokens=params["max_new_tokens"],
        temperature=params["temperature"],
        top_p=params["top_p"],
    )
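
# Quick smoke test (hypothetical; uncomment to verify generation outside the
# UI, assuming test-model.gguf is present):
#
#   print(generate_response("Hello!", "Test", "Normal"))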
def handle_chat(message: str, history: list, model: str, preset: str) -> str:
    """
    Handle the chat interaction.

    Args:
        message (str): The user's message.
        history (list): The conversation history.
        model (str): The model to use.
        preset (str): The performance preset.

    Returns:
        str: The generated response.
    """
    try:
        return generate_response(message, model, preset)
    except Exception as e:
        return f"⚠️ Error: {str(e)}"
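
# Note: gr.ChatInterface invokes handle_chat as fn(message, history, *additional_inputs),
# so `model` and `preset` are supplied by the two dropdowns wired up below.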
with gr.Blocks(title="BI CORP AI Assistant", theme="soft") as demo:
    gr.Markdown("# <center>Lake AI Assistant</center>")
    gr.Markdown("### <center>Powered by Lake 1 Chat</center>")
    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="🤖 Model Selection",
                choices=list(clients.keys()),
                value="Test",  # must be one of the configured model names
                interactive=True
            )
            preset_dropdown = gr.Dropdown(
                label="⚙️ Performance Preset",
                choices=["Fast", "Normal", "Quality"],
                value="Normal",
                interactive=True
            )
            model_info_md = gr.Markdown(
                value=model_configs["Test"]["specs"],
                label="📝 Model Specifications"
            )
        with gr.Column(scale=3):
            chat_interface = gr.ChatInterface(
                fn=handle_chat,
                additional_inputs=[model_dropdown, preset_dropdown],
                examples=[["Explain quantum computing", "Test", "Normal"]],
                chatbot=gr.Chatbot(height=600, label="💬 Conversation", show_copy_button=True),
                textbox=gr.Textbox(placeholder="Type your message...", container=False, scale=7, autofocus=True),
                submit_btn="🚀 Send"  # a plain string keeps this compatible across Gradio versions
            )
            clear_button = gr.Button("🧹 Clear History")
            clear_button.click(
                fn=lambda: None,
                inputs=[],
                outputs=chat_interface.chatbot,
                queue=False
            )

    # Swap the spec sheet whenever a different model is selected
    model_dropdown.change(
        fn=lambda model: model_configs[model]["specs"],
        inputs=model_dropdown,
        outputs=model_info_md,
        queue=False
    )

if __name__ == "__main__":
    # Note: Hugging Face Spaces expects Gradio apps on port 7860 by default;
    # a hardcoded 7865 suits local runs only unless the Space config sets app_port.
    demo.launch(server_port=7865)