#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import traceback

from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken


@manager.route('/templates', methods=['GET'])  # noqa: F821
@login_required
def templates():
    """Return every canvas template, each serialised with ``to_dict()``."""
    return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])


@manager.route('/list', methods=['GET'])  # noqa: F821
@login_required
def canvas_list():
    """Return the current user's canvases, most recently updated first."""
    canvases = [c.to_dict() for c in UserCanvasService.query(user_id=current_user.id)]
    # Negating the key sorts by "update_time" in descending order.
    return get_json_result(data=sorted(canvases, key=lambda x: -x["update_time"]))


@manager.route('/rm', methods=['POST'])  # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
    """Delete the canvases listed in ``canvas_ids``.

    Ids are processed in order; the first id that does not belong to the
    current user aborts the request with an OPERATING_ERROR result (ids
    handled before it have already been deleted).
    """
    for i in request.json["canvas_ids"]:
        if not UserCanvasService.query(user_id=current_user.id, id=i):
            return get_json_result(
                data=False, message='Only owner of canvas authorized for this operation.',
                code=RetCode.OPERATING_ERROR)
        UserCanvasService.delete_by_id(i)
    return get_json_result(data=True)


@manager.route('/set', methods=['POST'])  # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
    """Create a canvas (no ``id`` in the request) or update an existing one.

    The ``dsl`` field may arrive as a JSON string or as an already-decoded
    object; either way it is stored as a decoded object. On success the
    (possibly augmented) request payload is echoed back.
    """
    req = request.json
    req["user_id"] = current_user.id
    # Normalise "dsl": serialise non-strings, then decode so it is always an object.
    if not isinstance(req["dsl"], str):
        req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
    req["dsl"] = json.loads(req["dsl"])

    if "id" not in req:
        # Creation: titles must be unique per user.
        if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip()):
            return get_data_error_result(message=f"{req['title'].strip()} already exists.")
        req["id"] = get_uuid()
        if not UserCanvasService.save(**req):
            return get_data_error_result(message="Fail to save canvas.")
    else:
        # Update: only the owner may modify the canvas.
        if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
            return get_json_result(
                data=False, message='Only owner of canvas authorized for this operation.',
                code=RetCode.OPERATING_ERROR)
        UserCanvasService.update_by_id(req["id"], req)
    return get_json_result(data=req)


# The URL variable is required: the view takes ``canvas_id`` as an argument.
@manager.route('/get/<canvas_id>', methods=['GET'])  # noqa: F821
@login_required
def get(canvas_id):
    """Return the canvas identified by ``canvas_id`` as a dict."""
    e, c = UserCanvasService.get_by_id(canvas_id)
    if not e:
        return get_data_error_result(message="canvas not found.")
    return get_json_result(data=c.to_dict())


# The URL variable is required: the view takes ``canvas_id`` as an argument.
@manager.route('/getsse/<canvas_id>', methods=['GET'])  # type: ignore # noqa: F821
def getsse(canvas_id):
    """Return a canvas to a caller authenticated by API token.

    The ``Authorization`` header must consist of exactly two
    whitespace-separated parts; the second part is looked up as an
    ``APIToken`` ``beta`` value.
    """
    # Default to "" so a missing header yields an error result, not AttributeError.
    token = request.headers.get('Authorization', '').split()
    if len(token) != 2:
        return get_data_error_result(message='Authorization is not valid!')
    token = token[1]
    objs = APIToken.query(beta=token)
    if not objs:
        return get_data_error_result(message='Authentication error: API key is invalid!')
    e, c = UserCanvasService.get_by_id(canvas_id)
    if not e:
        return get_data_error_result(message="canvas not found.")
    return get_json_result(data=c.to_dict())


@manager.route('/completion', methods=['POST'])  # noqa: F821
@validate_request("id")
@login_required
def run():
    """Run the canvas ``id`` owned by the current user.

    Request fields read here: ``id`` (required), ``stream`` (default True),
    ``message`` (optional user input) and ``message_id`` (defaults to a new
    uuid). With ``stream`` truthy the answer is delivered as a
    ``text/event-stream`` response; otherwise a single JSON result is
    returned. In both modes the canvas state is written back through
    ``UserCanvasService.update_by_id``.
    """
    req = request.json
    stream = req.get("stream", True)
    e, cvs = UserCanvasService.get_by_id(req["id"])
    if not e:
        return get_data_error_result(message="canvas not found.")
    if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
        return get_json_result(
            data=False, message='Only owner of canvas authorized for this operation.',
            code=RetCode.OPERATING_ERROR)

    # Canvas is constructed from a JSON string.
    if not isinstance(cvs.dsl, str):
        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

    final_ans = {"reference": [], "content": ""}
    message_id = req.get("message_id", get_uuid())
    try:
        canvas = Canvas(cvs.dsl, current_user.id)
        if "message" in req:
            canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
            canvas.add_user_input(req["message"])
    except Exception as e:
        return server_error_response(e)

    if stream:
        def sse():
            """Yield server-sent events for one streamed canvas run."""
            try:
                for ans in canvas.run(stream=True):
                    if ans.get("running_status"):
                        # Progress notification: forward it without recording it.
                        yield "data:" + json.dumps({"code": 0, "message": "",
                                                    "data": {"answer": ans["content"],
                                                             "running_status": True}},
                                                   ensure_ascii=False) + "\n\n"
                        continue
                    # Remember the latest answer fields for persistence below.
                    for k in ans.keys():
                        final_ans[k] = ans[k]
                    ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
                    yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                               ensure_ascii=False) + "\n\n"

                canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
                canvas.history.append(("assistant", final_ans["content"]))
                if not canvas.path[-1]:
                    canvas.path.pop(-1)
                if final_ans.get("reference"):
                    canvas.reference.append(final_ans["reference"])
                cvs.dsl = json.loads(str(canvas))
                UserCanvasService.update_by_id(req["id"], cvs.to_dict())
            except Exception as e:
                # NOTE(review): the canvas is serialised before the trailing empty
                # path entry is popped, so that pop is not persisted (unlike the
                # success path above). Order kept as in the original -- confirm intent.
                cvs.dsl = json.loads(str(canvas))
                if not canvas.path[-1]:
                    canvas.path.pop(-1)
                UserCanvasService.update_by_id(req["id"], cvs.to_dict())
                traceback.print_exc()
                yield "data:" + json.dumps({"code": 500, "message": str(e),
                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                           ensure_ascii=False) + "\n\n"
            # Terminal event, sent after both success and failure.
            yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

        resp = Response(sse(), mimetype="text/event-stream")
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp

    # NOTE(review): the collapsed source does not show whether this return sat
    # inside the loop; it is read here as returning on the first final answer.
    for answer in canvas.run(stream=False):
        if answer.get("running_status"):
            continue
        final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
        canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
        if final_ans.get("reference"):
            canvas.reference.append(final_ans["reference"])
        cvs.dsl = json.loads(str(canvas))
        UserCanvasService.update_by_id(req["id"], cvs.to_dict())
        return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])})

    # The run produced no final answer: still return a well-formed result
    # instead of falling off the end of the view with ``None``.
    return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])})


@manager.route('/reset', methods=['POST'])  # noqa: F821
@validate_request("id")
@login_required
def reset():
    """Reset the canvas ``id`` owned by the current user and return its new DSL."""
    req = request.json
    try:
        e, user_canvas = UserCanvasService.get_by_id(req["id"])
        if not e:
            return get_data_error_result(message="canvas not found.")
        if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
            return get_json_result(
                data=False, message='Only owner of canvas authorized for this operation.',
                code=RetCode.OPERATING_ERROR)

        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        canvas.reset()
        req["dsl"] = json.loads(str(canvas))
        UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
        return get_json_result(data=req["dsl"])
    except Exception as e:
        return server_error_response(e)


@manager.route('/input_elements', methods=['GET'])  # noqa: F821
@login_required
def input_elements():
    """Return the input elements of component ``component_id`` in canvas ``id``.

    Both values are read from the query string; the canvas must belong to
    the current user.
    """
    cvs_id = request.args.get("id")
    cpn_id = request.args.get("component_id")
    try:
        e, user_canvas = UserCanvasService.get_by_id(cvs_id)
        if not e:
            return get_data_error_result(message="canvas not found.")
        if not UserCanvasService.query(user_id=current_user.id, id=cvs_id):
            return get_json_result(
                data=False, message='Only owner of canvas authorized for this operation.',
                code=RetCode.OPERATING_ERROR)

        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        return get_json_result(data=canvas.get_component_input_elements(cpn_id))
    except Exception as e:
        return server_error_response(e)


@manager.route('/debug', methods=['POST'])  # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
    """Run one component of a canvas in debug mode with the supplied inputs.

    ``params`` is assigned to the component's ``_param.debug_inputs``; the
    value returned by the component's ``debug()`` is converted with
    ``to_dict(orient="records")`` and returned.
    """
    req = request.json
    for p in req["params"]:
        # Explicit validation: ``assert`` is stripped when Python runs with -O.
        if not p.get("key"):
            return get_data_error_result(message="Each item of `params` requires a non-empty `key`.")
    try:
        e, user_canvas = UserCanvasService.get_by_id(req["id"])
        if not e:
            return get_data_error_result(message="canvas not found.")
        if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
            return get_json_result(
                data=False, message='Only owner of canvas authorized for this operation.',
                code=RetCode.OPERATING_ERROR)

        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        canvas.get_component(req["component_id"])["obj"]._param.debug_inputs = req["params"]
        df = canvas.get_component(req["component_id"])["obj"].debug()
        return get_json_result(data=df.to_dict(orient="records"))
    except Exception as e:
        return server_error_response(e)


@manager.route('/test_db_connect', methods=['POST'])  # noqa: F821
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
def test_db_connect():
    """Check that a database is reachable with the supplied credentials.

    Supported ``db_type`` values are ``mysql``, ``mariadb``, ``postgresql``
    and ``mssql``. Any exception raised while connecting is reported through
    ``server_error_response``.
    """
    req = request.json
    try:
        if req["db_type"] in ["mysql", "mariadb"]:
            db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
                               password=req["password"])
        elif req["db_type"] == 'postgresql':
            db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
                                    password=req["password"])
        elif req["db_type"] == 'mssql':
            # Imported lazily so pyodbc is only needed when MSSQL is tested.
            import pyodbc
            connection_string = (
                f"DRIVER={{ODBC Driver 17 for SQL Server}};"
                f"SERVER={req['host']},{req['port']};"
                f"DATABASE={req['database']};"
                f"UID={req['username']};"
                f"PWD={req['password']};"
            )
            db = pyodbc.connect(connection_string)
            try:
                cursor = db.cursor()
                cursor.execute("SELECT 1")
                cursor.close()
            finally:
                # Always release the probe connection, even if the query fails.
                db.close()
        else:
            # NOTE(review): every other call site passes an exception object to
            # server_error_response; confirm it accepts a plain string.
            return server_error_response("Unsupported database type.")
        if req["db_type"] != 'mssql':
            db.connect()
            db.close()
        return get_json_result(data="Database Connection Successful!")
    except Exception as e:
        return server_error_response(e)