from flask import Flask, request, jsonify, render_template
from backend_utils import initialize_all_components, make_predictions
from config import classifier_class_mapping, config
import subprocess
# from flask_cors import CORS, cross_origin
import json


# TODO: downgrade scikit-learn to version 1.0.2

# Flask application object; the routes below are registered on it.
app = Flask(__name__)
# CORS(app)

# Build every pipeline component once, at import time, and expose each one as
# a module-level global so the request handlers can share them.
components = initialize_all_components(config)
(
    db_metadata,
    db_constructor,
    db_params,
    ex_list,
    model_retrieval,
    model_generative,
    tokenizer_generative,
    model_classifier,
    classifier_head,
    tokenizer_classifier,
) = (components[position] for position in range(10))  # reads indices 0..9 in order

def call_predict_api(
    input_query,
    model_retrieval,
    model_generative,
    model_classifier,
    classifier_head,
    tokenizer_generative,
    tokenizer_classifier,
    db_metadata,
    db_constructor,
    db_params,
    ex_list,
    config,
):
    '''
    Thin wrapper around ``make_predictions``.

    Every argument is forwarded positionally, in the order
    ``make_predictions`` is called with here, and its result is returned
    unchanged.
    '''
    forwarded_args = (
        input_query,
        model_retrieval,
        model_generative,
        model_classifier,
        classifier_head,
        tokenizer_generative,
        tokenizer_classifier,
        db_metadata,
        db_constructor,
        db_params,
        ex_list,
        config,
    )
    return make_predictions(*forwarded_args)

@app.route("/")
def hello_world():
    '''
    Serve the root URL by rendering the ``index.html`` template.
    '''
    landing_page = render_template("index.html")
    return landing_page

@app.route('/predict', methods=['POST', 'GET'])
def predict():
    '''
    Run the prediction pipeline for the ``user_query`` request parameter.

    The query is read from the URL query string. For a POST request that
    does not carry it there, the ``user_query`` key of a JSON object body is
    used as a fallback.

    Returns:
        A JSON response ``{"predictions": <result of call_predict_api>}``,
        or a JSON error response with HTTP status 400 when no
        ``user_query`` was supplied.
    '''
    user_query = request.args.get("user_query")
    if user_query is None and request.method == 'POST':
        # silent=True returns None instead of raising when the body is
        # missing or is not valid JSON.
        payload = request.get_json(silent=True)
        if isinstance(payload, dict):
            user_query = payload.get('user_query')
    print(f"user_query: {user_query}")

    if user_query is None:
        # Previously the view fell off the end and returned None here, which
        # Flask rejects as an invalid response (HTTP 500).
        return jsonify({'error': "missing 'user_query' parameter"}), 400

    print("predicting")
    predictions = call_predict_api(
        user_query,
        model_retrieval,
        model_generative,
        model_classifier, classifier_head,
        tokenizer_generative, tokenizer_classifier,
        db_metadata, db_constructor, db_params, ex_list,
        config
    )
    print(predictions)

    # The string 'null' used to be special-cased, but that branch produced
    # exactly the same payload as the general case below.
    return jsonify({'predictions': predictions})
if __name__ == "__main__":
    # When executed as a script, start the server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")