import json
import os
import sys

import torch
from flask import Flask, request, jsonify, render_template
from transformers import GenerationConfig

app = Flask(__name__)
# Define paths
drive_folder = '/app'  # Path where the model files are downloaded inside the Docker container
tokenizer_config_file = os.path.join(drive_folder, 'tokenizer_config.json')
model_config_file = os.path.join(drive_folder, 'config.json')
# Make the custom tokenizer/model modules in drive_folder importable
sys.path.append(drive_folder)
# Debugging print statements
print(f"Drive folder: {drive_folder}")
print(f"Tokenizer config file: {tokenizer_config_file}")
print(f"Model config file: {model_config_file}")
# Import the custom configuration, tokenizer, and model classes
try:
    from configuration_qwen import QWenConfig
    from tokenization_qwen import QWenTokenizer
    from modeling_qwen import QWenLMHeadModel
    print("Imported custom classes successfully!")
except ImportError as e:
    print(f"Import error: {e}")
    raise
# Ensure the tokenizer configuration file exists
if not os.path.exists(tokenizer_config_file):
    raise FileNotFoundError(f"Tokenizer configuration file not found at {tokenizer_config_file}")
# Load the tokenizer configuration
with open(tokenizer_config_file, 'r') as f:
    tokenizer_config = json.load(f)
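# Note: tokenizer_config is not referenced again in this script;
# QWenTokenizer.from_pretrained below reads the same file on its own.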
# Load the model configuration from the provided config file
with open(model_config_file, 'r') as f:
    model_config = json.load(f)
# Disable FlashAttention for non-supported GPUs
model_config["use_flash_attn"] = False
model_config["use_dynamic_ntk"] = False # Disable other advanced features if necessary
# Use the provided configuration for model initialization
try:
    tokenizer = QWenTokenizer.from_pretrained(drive_folder)
    config = QWenConfig.from_pretrained(drive_folder, **model_config)
    model = QWenLMHeadModel.from_pretrained(drive_folder, config=config)
    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    model.eval()  # inference only: disables dropout etc.
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print("Error loading model or tokenizer:", e)
    raise
def generate_text(model, tokenizer, prompt, max_length=200, temperature=0.7, top_k=50, top_p=0.9):
    try:
        # Tokenize the input
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
        # Set up generation configuration; total budget = prompt length + max_length new tokens
        generation_config = GenerationConfig(
            max_length=max_length + len(input_ids[0]),
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        # Generate text using sampling; no_grad avoids tracking gradients at inference
        with torch.no_grad():
            outputs = model.generate(input_ids, generation_config=generation_config)
        # Decode the generated sequence
        decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the prompt from the output; find() returns -1 if the prompt is
        # not reproduced verbatim, in which case return the full decoded text
        start_index = decoded_output.find(prompt)
        if start_index == -1:
            return decoded_output.strip()
        return decoded_output[start_index + len(prompt):].strip()
    except Exception as e:
        print("Error during text generation:", e)
        raise
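# Quick sanity check from a Python shell once the model has loaded
# (the prompt below is illustrative):
#
#     sample = generate_text(model, tokenizer, "Once upon a time", max_length=50)
#     print(sample)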
@app.route('/')
def home():
    return render_template('index.html')
# All of these substring pairs map to the same canned identity response
IDENTITY_PATTERNS = [
    ("urname", "what"),
    ("your name", "what"),
    ("tell ", "your name"),
    ("what", "you go by"),
    ("what", "call yourself"),
    ("what", "they call you"),
]
IDENTITY_RESPONSE = "I am Shanks, a large language model developed by Motaung.inc"

@app.route('/generate', methods=['POST'])
def generate():
    user_input = request.form['user_input']
    try:
        if any(a in user_input and b in user_input for a, b in IDENTITY_PATTERNS):
            response_text = IDENTITY_RESPONSE
        else:
            response_text = generate_text(model, tokenizer, user_input)
        return jsonify({"response": response_text})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
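# Example request against the running server (default host/port above):
#
#     curl -X POST -d "user_input=Hello, who are you?" http://localhost:8080/generate
#
# The endpoint returns JSON, e.g. {"response": "..."} or {"error": "..."} with status 500.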