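"""Gradio Space app for the SBX Certification Query Helper.

Serves a fine-tuned Mistral-7B model (LoRA adapter applied via PEFT) behind a
simple question-and-answer UI for SBX certification queries.
"""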
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PeftModel, PeftConfig
# Model and tokenizer initialization
MODEL_NAME = "satishpednekar/sbxcertqueryhelper"
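# NOTE: MODEL_NAME is only referenced by load_model_org() below; the loader the
# app actually uses, load_model(), resolves the base model from the adapter's
# PeftConfig instead.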
def load_model_org():
    """Alternative loader: the fine-tuned model repo directly, in float16 (unused)."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Load without 8-bit quantization; float16 halves memory and suits GPU inference
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    return model, tokenizer
def load_model_gpu():
    """Alternative loader: base model plus LoRA adapter, in float16 on GPU (unused)."""
    # Load the base model first
    base_model = AutoModelForCausalLM.from_pretrained(
        "unsloth/mistral-7b-v0.3",  # base model the adapter was trained on
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Apply the PEFT (LoRA) adapter weights on top of the base model
    model = PeftModel.from_pretrained(
        base_model,
        "satishpednekar/sbx-qhelper-mistral-loraWeights",  # trained LoRA weights
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "unsloth/mistral-7b-v0.3",
        trust_remote_code=True,
    )
    return model, tokenizer
def load_model():
    """Load the base model plus LoRA adapter in float32 on CPU (used by the app)."""
    config = PeftConfig.from_pretrained("satishpednekar/sbx-qhelper-mistral-loraWeights")
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.float32,  # full precision; no quantization on CPU
        device_map=None,
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(
        model,
        "satishpednekar/sbx-qhelper-mistral-loraWeights",
        torch_dtype=torch.float32,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name_or_path,
        trust_remote_code=True,
    )
    model = model.to("cpu").eval()
    return model, tokenizer
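# Optional: PEFT's merge_and_unload() can fold the LoRA weights into the base
# model, removing the adapter indirection at inference time. Whether that
# measurably speeds up CPU inference here is an assumption, so it is left as a
# commented sketch:
#   model = model.merge_and_unload()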
# Initialize model and tokenizer
print("Loading model...")
model, tokenizer = load_model()
print("Model loaded successfully!")
def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.95):
    """Generate a response to `prompt` using the fine-tuned model."""
    try:
        # Tokenize the prompt and move it to whichever device the model is on
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Generate; note that max_length counts the prompt tokens as well
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
        )
        # Decode the generated tokens back into text
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt if it appears at the start of the output
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        return response
    except Exception as e:
        return f"An error occurred: {e}"
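# Illustrative usage (placeholder question, not from the training data):
#   print(generate_response("What topics does the SBX certification exam cover?"))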
# Build the Gradio interface
def main():
    with gr.Blocks(title="SBX Certification Query Helper") as demo:
        gr.Markdown("""
        # SBX Certification Query Helper
        Ask questions about SBX certifications and get detailed answers!
        """)
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Your Question",
                    placeholder="Enter your question about SBX certifications...",
                    lines=3,
                )
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher values make the output more random; lower values keep it focused",
                    )
                    max_length = gr.Slider(
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Maximum Length",
                        info="Maximum length of the generated response",
                    )
                submit_btn = gr.Button("Get Answer", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Answer",
                    lines=10,
                    show_copy_button=True,
                )
        # Wire up the button; inputs map positionally to
        # generate_response(prompt, max_length, temperature)
        submit_btn.click(
            fn=generate_response,
            inputs=[input_text, max_length, temperature],
            outputs=output_text,
        )
        gr.Markdown("""
        ### Tips:
        - Be specific in your questions
        - Include the certification name if you're asking about a specific certification
        - Adjust the temperature slider to control response creativity
        """)
    return demo
if __name__ == "__main__":
    demo = main()
    demo.queue()  # queue requests so concurrent users are handled in order
    demo.launch(
        share=True,  # create a public share link
        server_name="0.0.0.0",  # listen on all network interfaces
    )