# (removed web-viewer artifacts — file-size header, git blame hashes, and line-number gutter — that were not part of the source)
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import os
# Initialize logger
# DEBUG level so the per-request prompt/response traces (logging.debug below) are emitted.
logging.basicConfig(level=logging.DEBUG)
# Load tokenizer and model
logging.info("Loading model...")
# Hugging Face Hub repository holding both the tokenizer and the causal-LM weights.
model_repo = "hsb06/toghetherAi-model"
tokenizer = AutoTokenizer.from_pretrained(model_repo)
# fp16 weights halve memory use; run on GPU when available, otherwise fall back to CPU.
model = AutoModelForCausalLM.from_pretrained(model_repo, torch_dtype=torch.float16).to("cuda" if torch.cuda.is_available() else "cpu")
logging.info("Model loaded successfully.")
app = Flask(__name__)
# Enable CORS so a browser front-end served from another origin can call /chat.
CORS(app)
def generate_response(prompt):
    """Generate a model completion for *prompt* and return the trimmed bot reply.

    Args:
        prompt: Full chat prompt; callers end it with "<bot>:" so the model
            continues in the bot's voice.

    Returns:
        The newly generated text, with any hallucinated "<human>" turn removed
        and surrounding whitespace stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]
    # Pure inference: disable autograd bookkeeping to cut memory use and
    # speed up generation (the original ran generate() with grad enabled).
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.7,
            top_k=50,
            return_dict_in_generate=True,
        )
    # Keep only the freshly generated tokens; drop the echoed prompt.
    token = outputs.sequences[0, input_length:]
    full_response = tokenizer.decode(token, skip_special_tokens=True)
    # Sampling may continue into the next human turn; cut everything from
    # the first "<human>" marker onward.
    if "<human>" in full_response:
        trimmed_response = full_response.split("<human>")[0].strip()
    else:
        trimmed_response = full_response.strip()
    # Lazy %-formatting avoids building the string when DEBUG is disabled.
    logging.debug("Trimmed response: %s", trimmed_response)
    return trimmed_response
@app.route("/", methods=["GET"])
def home():
    """Health-check endpoint confirming the server is up and reachable."""
    payload = {"message": "Flask app is running!"}
    return jsonify(payload)
@app.route("/chat", methods=["POST"])
def chat():
    """Handle one chat turn.

    Expects a JSON body like {"message": "..."} and returns
    {"response": "..."} with the model's reply.

    Returns a 400 JSON error when the body is not valid JSON or the
    "message" field is missing, empty, or not a string — previously a
    malformed body made request.json raise and produced an opaque 500.
    """
    # silent=True yields None instead of raising on a missing/malformed body.
    data = request.get_json(silent=True) or {}
    user_input = data.get("message", "")
    if not isinstance(user_input, str) or not user_input.strip():
        return jsonify({"error": "Field 'message' must be a non-empty string."}), 400
    prompt = f"<human>: {user_input}\n<bot>:"
    # Lazy %-formatting: no string work unless the level is enabled.
    logging.info("User input: %s", user_input)
    logging.debug("Generated prompt: %s", prompt)
    response = generate_response(prompt)
    logging.info("Generated response: %s", response)
    return jsonify({"response": response})
if __name__ == "__main__":
    # Default to 7860 (the Hugging Face Spaces convention) unless PORT overrides it.
    listen_port = int(os.environ.get("PORT", 7860))
    logging.info(f"Starting Flask app on port {listen_port}")
    # Bind to all interfaces so the container/host can route traffic in.
    app.run(debug=False, host="0.0.0.0", port=listen_port)
|