from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import os

logging.basicConfig(level=logging.DEBUG)

logging.info("Loading model...") |
|
model_repo = "hsb06/toghetherAi-model" |
|
tokenizer = AutoTokenizer.from_pretrained(model_repo) |
|
model = AutoModelForCausalLM.from_pretrained(model_repo, torch_dtype=torch.float16).to("cuda" if torch.cuda.is_available() else "cpu") |
|
logging.info("Model loaded successfully.") |
|
|
|
app = Flask(__name__)
CORS(app)  # Allow cross-origin requests so a separate frontend can call this API.


def generate_response(prompt):
    # Tokenize the prompt and move the tensors to the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.7,
        top_k=50,
        return_dict_in_generate=True,
    )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated_tokens = outputs.sequences[0, input_length:]
    full_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # The model may start a new "<human>" turn; keep only the text before it.
    if "<human>" in full_response:
        trimmed_response = full_response.split("<human>")[0].strip()
    else:
        trimmed_response = full_response.strip()

    logging.debug(f"Trimmed response: {trimmed_response}")
    return trimmed_response


@app.route("/", methods=["GET"]) |
|
def home(): |
|
return jsonify({"message": "Flask app is running!"}) |
|
|
|
@app.route("/chat", methods=["POST"]) |
|
def chat(): |
|
data = request.json |
|
user_input = data.get("message", "") |
|
prompt = f"<human>: {user_input}\n<bot>:" |
|
logging.info(f"User input: {user_input}") |
|
logging.debug(f"Generated prompt: {prompt}") |
|
response = generate_response(prompt) |
|
logging.info(f"Generated response: {response}") |
|
return jsonify({"response": response}) |
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces; the port can be overridden via the PORT env var.
    port = int(os.getenv("PORT", 7860))
    logging.info(f"Starting Flask app on port {port}")
    app.run(debug=False, host="0.0.0.0", port=port)
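
# A quick way to exercise the /chat endpoint once the server is running. The
# localhost host and port 7860 below are assumptions matching the defaults above;
# adjust them if PORT is set differently.
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello, how are you?"}'
#
# The endpoint responds with JSON of the form {"response": "..."}.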