from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import os

# Configure logging; DEBUG level also emits the per-request prompt logs below.
logging.basicConfig(level=logging.DEBUG)

# Load tokenizer and model once at startup so requests don't pay the load cost.
logging.info("Loading model...")
model_repo = "hsb06/toghetherAi-model"
tokenizer = AutoTokenizer.from_pretrained(model_repo)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use float16 to halve memory on GPU; fall back to float32 on CPU, where half
# precision is poorly supported and can fail during generation.
model = AutoModelForCausalLM.from_pretrained(
    model_repo,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
logging.info("Model loaded successfully.")

app = Flask(__name__)
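# Allow cross-origin requests so a browser front-end hosted elsewhere can call /chat.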
CORS(app)

def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.7,
        top_k=50,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,  # silences the missing-pad-token warning
    )
    # Slice off the prompt tokens so only newly generated text is decoded.
    generated_tokens = outputs.sequences[0, input_length:]
    full_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # The model may keep generating a follow-up "<human>" turn; cut it off.
    if "<human>" in full_response:
        trimmed_response = full_response.split("<human>")[0].strip()
    else:
        trimmed_response = full_response.strip()

    logging.debug(f"Trimmed response: {trimmed_response}")
    return trimmed_response
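# Illustrative example (hypothetical output): generate_response("<human>: Hi\n<bot>:")
# might return something like "Hello! How can I help you?" once any trailing
# "<human>" turn has been stripped.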


@app.route("/", methods=["GET"])
def home():
    return jsonify({"message": "Flask app is running!"})

@app.route("/chat", methods=["POST"])
def chat():
    data = request.json
    user_input = data.get("message", "")
    prompt = f"<human>: {user_input}\n<bot>:"
    logging.info(f"User input: {user_input}")
    logging.debug(f"Generated prompt: {prompt}")
    response = generate_response(prompt)
    logging.info(f"Generated response: {response}")
    return jsonify({"response": response})

if __name__ == "__main__":
    # 7860 is the default port on Hugging Face Spaces; override via $PORT.
    port = int(os.getenv("PORT", 7860))
    logging.info(f"Starting Flask app on port {port}")
    app.run(debug=False, host="0.0.0.0", port=port)
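
# A quick usage sketch, assuming the server is running locally on port 7860
# (the message below is illustrative):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!"}'
#
# The endpoint replies with JSON of the form {"response": "<generated text>"}.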