import gradio as gr
from llama_cpp import Llama
import torch
import os
from accelerate import Accelerator
import tensorflow as tf # Import TensorFlow
import numpy as np # For handling input data
# Select a device for any PyTorch work (informational here: llama_cpp manages
# its own GPU offload via n_gpu_layers).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device set to:", device)
# Initialize the accelerator
accelerator = Accelerator()
class LocalInferenceClient:
def __init__(self, model_name: str, model_path: str):
"""
Initialize the inference client with the model.
Args:
model_name (str): The name of the model.
model_path (str): The path to the model file or directory.
"""
self.model_name = model_name
self.model_path = model_path
# Initialize the Llama model specifically for gguf
self.model = Llama(model_path=model_path, n_ctx=2048, n_threads=8, n_gpu_layers=5)
        # llama_cpp's Llama is not a torch module, so accelerator.prepare()
        # returns it unchanged; GPU offload is controlled by n_gpu_layers above.
        self.model = accelerator.prepare(self.model)
# Load the TensorFlow Lite model
self.tflite_interpreter = tf.lite.Interpreter(model_path='model.tflite')
self.tflite_interpreter.allocate_tensors()
# Get input and output tensors
self.input_details = self.tflite_interpreter.get_input_details()
self.output_details = self.tflite_interpreter.get_output_details()
def text_generation(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
"""
Generate text based on the provided prompt.
Args:
prompt (str): The input prompt.
max_new_tokens (int): The maximum number of tokens to generate.
temperature (float): Sampling temperature.
top_p (float): Nucleus sampling probability.
Returns:
str: The generated text.
"""
# Use the Llama model for text generation
response = self.model.create_chat_completion(
messages=[{"role": "user", "content": prompt}],
max_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p
)
# Print the response to understand its structure
print("Response from model:", response)
# Access the content correctly based on the response structure
if 'choices' in response and len(response['choices']) > 0:
return response['choices'][0]['message']['content'] # Access the content key
else:
return "⚠️ Error: Unexpected response format."
def run_tflite_model(self, input_data: np.ndarray) -> np.ndarray:
"""
Run inference using the TensorFlow Lite model.
Args:
input_data (np.ndarray): Input data for the model.
Returns:
np.ndarray: Output data from the model.
"""
# Set the input tensor
self.tflite_interpreter.set_tensor(self.input_details[0]['index'], input_data)
# Run the model
self.tflite_interpreter.invoke()
# Get the output tensor
output_data = self.tflite_interpreter.get_tensor(self.output_details[0]['index'])
return output_data
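# Illustrative sketch, not wired into the app: one way to exercise the TFLite
# path above. The zero-filled dummy input is an assumption for demonstration
# only, and it presumes the .tflite model declares a static input shape; the
# shape and dtype are read from the interpreter's own input details.
def example_tflite_inference(client: LocalInferenceClient) -> np.ndarray:
    input_spec = client.input_details[0]
    dummy_input = np.zeros(input_spec['shape'], dtype=input_spec['dtype'])
    return client.run_tflite_model(dummy_input)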
# Specify the model paths for gguf models
model_configs = {
"Test": {
"path": r"./test-model.gguf",
"specs": """
## Lake 1 Chat Specifications
- **Architecture**: Test
- **Parameters**: IDK
- **Capabilities**: test
- **Intended Use**: test
"""
}
}
# Set up a dictionary mapping model names to their clients
clients = {name: LocalInferenceClient(name, config['path']) for name, config in model_configs.items()}
# Presets for performance/quality tradeoffs
presets = {
"Test": {
"Fast": {"max_new_tokens": 100, "temperature": 1.0, "top_p": 0.9},
"Normal": {"max_new_tokens": 200, "temperature": 0.7, "top_p": 0.95},
"Quality": {"max_new_tokens": 300, "temperature": 0.5, "top_p": 0.90},
}
}
# A system prompt for the model
system_messages = {
"Test": "You are Lake 1 Chat, a powerful open-source reasoning model. Think carefully and answer step by step.",
}
def generate_response(message: str, model_name: str, preset: str) -> str:
"""
Generate a response based on the user's message.
Args:
message (str): The user's message.
model_name (str): The name of the model to use.
preset (str): The performance preset to apply.
Returns:
str: The generated response.
"""
client = clients[model_name]
params = presets[model_name][preset]
system_msg = system_messages[model_name]
    prompt = f"{system_msg}\n\nUser: {message}\nAssistant:"
return client.text_generation(
prompt,
max_new_tokens=params["max_new_tokens"],
temperature=params["temperature"],
top_p=params["top_p"]
)
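# Minimal sketch for exercising the text-generation pipeline without the Gradio
# UI. The model name "Test" and preset "Fast" come from the dicts above; the
# prompt itself is an arbitrary example. The function is defined but never
# called by the app; it could be invoked manually (e.g. under the __main__
# guard below, before demo.launch()).
def smoke_test() -> None:
    reply = generate_response("Say hello in one sentence.", "Test", "Fast")
    print("Smoke test reply:", reply)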
def handle_chat(message: str, history: list, model: str, preset: str) -> str:
"""
Handle the chat interaction.
Args:
message (str): The user's message.
history (list): The conversation history.
model (str): The model to use.
preset (str): The performance preset.
Returns:
str: The generated response.
"""
try:
return generate_response(message, model, preset)
except Exception as e:
return f"⚠️ Error: {str(e)}"
with gr.Blocks(title="BI CORP AI Assistant", theme="soft") as demo:
gr.Markdown("# <center>Lake AI Assistant</center>")
gr.Markdown("### <center>Powered by Lake 1 Chat</center>")
with gr.Row():
with gr.Column(scale=1):
model_dropdown = gr.Dropdown(
label="🤖 Model Selection",
choices=list(clients.keys()),
value="Lake 1 Chat",
interactive=True
)
preset_dropdown = gr.Dropdown(
label="⚙️ Performance Preset",
choices=["Fast", "Normal", "Quality"],
value="Normal",
interactive=True
)
model_info_md = gr.Markdown(
value=model_configs["Test"]["specs"],
label="📝 Model Specifications"
)
with gr.Column(scale=3):
chat_interface = gr.ChatInterface(
fn=handle_chat,
additional_inputs=[model_dropdown, preset_dropdown],
examples=[["Explain quantum computing", "Test", "Normal"]],
chatbot=gr.Chatbot(height=600, label="💬 Conversation", show_copy_button=True),
textbox=gr.Textbox(placeholder="Type your message...", container=False, scale=7, autofocus=True),
submit_btn=gr.Button("🚀 Send", variant="primary")
)
clear_button = gr.Button("🧹 Clear History")
clear_button.click(
fn=lambda: None,
inputs=[],
outputs=chat_interface.chatbot,
queue=False
)
model_dropdown.change(
fn=lambda model: model_configs[model]["specs"],
inputs=model_dropdown,
outputs=model_info_md,
queue=False
)
if __name__ == "__main__":
demo.launch(server_port=7865)