diff --git a/.env.template b/.env.template deleted file mode 100644 index aa91e736d95b6235900ed76f6f76950fdd3e1521..0000000000000000000000000000000000000000 --- a/.env.template +++ /dev/null @@ -1,9 +0,0 @@ -# Database -HOST=127.0.0.1 -PORT=8000 -DATABASE_URL="postgresql://infiniflow:infiniflow@localhost/docgpt" - -# S3 Storage -MINIO_HOST="127.0.0.1:9000" -MINIO_USR="infiniflow" -MINIO_PWD="infiniflow_docgpt" diff --git a/Cargo.toml b/Cargo.toml deleted file mode 100644 index aa229a54b1b0d8345793c109c44a04c27ae61bbf..0000000000000000000000000000000000000000 --- a/Cargo.toml +++ /dev/null @@ -1,42 +0,0 @@ -[package] -name = "doc_gpt" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -actix-web = "4.3.1" -actix-rt = "2.8.0" -actix-files = "0.6.2" -actix-multipart = "0.4" -actix-session = { version = "0.5" } -actix-identity = { version = "0.4" } -actix-web-httpauth = { version = "0.6" } -actix-ws = "0.2.5" -uuid = { version = "1.6.1", features = [ - "v4", - "fast-rng", - "macro-diagnostics", -] } -thiserror = "1.0" -postgres = "0.19.7" -sea-orm = { version = "0.12.9", features = ["sqlx-postgres", "runtime-tokio-native-tls", "macros"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1.0" -tracing-subscriber = "0.3.18" -dotenvy = "0.15.7" -listenfd = "1.0.1" -chrono = "0.4.31" -migration = { path = "./migration" } -minio = "0.1.0" -futures-util = "0.3.29" -actix-multipart-extract = "0.1.5" -regex = "1.10.2" -tokio = { version = "1.35.1", features = ["rt", "time", "macros"] } - -[[bin]] -name = "doc_gpt" - -[workspace] -members = [".", "migration"] diff --git a/python/conf/mapping.json b/conf/mapping.json old mode 100755 new mode 100644 similarity index 100% rename from python/conf/mapping.json rename to conf/mapping.json diff --git a/conf/private.pem b/conf/private.pem new file mode 100644 index 0000000000000000000000000000000000000000..ff333058d0d8087fadb09295c369a59f1d46a166 --- /dev/null +++ b/conf/private.pem @@ -0,0 +1,30 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: DES-EDE3-CBC,EFF8327C41E531AD + +7jdPFDAA6fiTzOIU7XGzKuT324JKZEcK5vBRJqBkA5XO6ENN1wLdhh3zQbl1Ejfv +KMSUIgbtQEJB4bvOzS//okbZa1vCNYuTS/NGcpKUnhqdOmAL3hl/kOtOLLjTZrwo +3KX8iujLH7wQ64GxArtpUuaFq1k0whN1BB5RGJp3IO/L6pMpSWVRKO+JPUrD1Ujr +XA/LUKQJaZtXVUVOYPtIwbyqPsh93QBetJnRwwV3gNOwGpcX2jDpyTxDUkLJCPPg +6Hw0pwlQEd8A11sjxCBbASwLeJO1L0w69QiX9chyOkZ+sfDsVpPt/wf1NexA7Cdj +9uifJ4JGbby39QD6mInZGtnRzQRdafjuXlBR2I0Qa7fBRu8QsfhmLbWZfWno7j08 +4bAAoqB1vRNfSu8LVJXdEEh/HKuwu11pgRr5eH8WQ3hJg+Y2k7zDHpp1VaHL7/Kn +S+aN5bhQ4Xt0Ujdi1+rsmNchnF6LWsDezHWJeWUM6X7dJnqIBl8oCyghbghT8Tyw +aEKWXc2+7FsP5yd0NfG3PFYOLdLgfI43pHTAv5PEQ47w9r1XOwfblKKBUDEzaput +T3t5wQ6wxdyhRxeO4arCHfe/i+j3fzvhlwgbuwrmrkWGWSS86eMTaoGM8+uUrHv0 +6TbU0tj6DKKUslVk1dCHh9TnmNsXZuLJkceZF38PSKNxhzudU8OTtzhS0tFL91HX +vo7N+XdiGMs8oOSpjE6RPlhFhVAKGJpXwBj/vXLLcmzesA7ZB2kYtFKMIdsUQpls +PE/4K5PEX2d8pxA5zxo0HleA1YjW8i5WEcDQThZQzj2sWvg06zSjenVFrbCm9Bro +hFpAB/3zJHxdRN2MpNpvK35WITy1aDUdX1WdyrlcRtIE5ssFTSoxSj9ibbDZ78+z +gtbw/MUi6vU6Yz1EjvoYu/bmZAHt9Aagcxw6k58fjO2cEB9njK7xbbiZUSwpJhEe +U/PxK+SdOU/MmGKeqdgqSfhJkq0vhacvsEjFGRAfivSCHkL0UjhObU+rSJ3g1RMO +oukAev6TOAwbTKVWjg3/EX+pl/zorAgaPNYFX64TSH4lE3VjeWApITb9Z5C/sVxR +xW6hU9qyjzWYWY+91y16nkw1l7VQvWHUZwV7QzTScC2BOzDVpeqY1KiYJxgoo6sX +ZCqR5oh4vToG4W8ZrRyauwUaZJ3r+zhAgm+6n6TJQNwFEl0muji+1nPl32EiFsRs +qR6CtuhUOVQM4VnILDwFJfuGYRFtKzQgvseLNU4ZqAVqQj8l4ARGAP2P1Au/uUKy 
+oGzI7a+b5MvRHuvkxPAclOgXgX/8yyOLaBg+mgaqv9h2JIJD28PzouFl3BajRaVB +7GWTnROJYhX5SuX/g585SLRKoQUtK0WhdJCjTRfyRJPwfdppgdTbWO99R4G+ir02 +JQdSkZf2vmZRXenPNTEPDOUY6nVN6sUuBjmtOwoUF194ODgpYB6IaHqK08sa1pUh +1mZyxitHdPbygePTe20XWMZFoK2knAqN0JPPbbNjCqiVV+7oqQAnkDIutspu9t2m +ny3jefFmNozbblQMghLUrq+x9wOEgvS76Sqvq3DG/2BkLzJF3MNkvw== +-----END RSA PRIVATE KEY----- diff --git a/conf/public.pem b/conf/public.pem new file mode 100644 index 0000000000000000000000000000000000000000..3fbcfe189593c174d9893721839dcffadf7bcce8 --- /dev/null +++ b/conf/public.pem @@ -0,0 +1,9 @@ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/ +z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp +2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOO +UEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVK +RNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK +6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs +2wIDAQAB +-----END PUBLIC KEY----- diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd8cca1ad8f3533ec8860d6062025507761c73e5 --- /dev/null +++ b/conf/service_conf.yaml @@ -0,0 +1,28 @@ +authentication: + client: + switch: false + http_app_key: + http_secret_key: + site: + switch: false +permission: + switch: false + component: false + dataset: false +ragflow: + # you must set real ip address, 127.0.0.1 and 0.0.0.0 is not supported + host: 127.0.0.1 + http_port: 9380 +database: + name: 'rag_flow' + user: 'root' + passwd: 'infini_rag_flow' + host: '123.60.95.134' + port: 5455 + max_connections: 100 + stale_timeout: 30 +oauth: + github: + client_id: 302129228f0d96055bee + secret_key: e518e55ccfcdfcae8996afc40f110e9c95f14fc4 + url: https://github.com/login/oauth/access_token \ No newline at end of file diff --git a/docker/.env b/docker/.env deleted file mode 100644 index fefb602e242ed542ae68c2792aa68e0207856867..0000000000000000000000000000000000000000 --- a/docker/.env +++ /dev/null @@ -1,21 +0,0 @@ -# Version of Elastic products -STACK_VERSION=8.11.3 - -# Set the cluster name -CLUSTER_NAME=docgpt - -# Port to expose Elasticsearch HTTP API to the host -ES_PORT=9200 - -# Port to expose Kibana to the host -KIBANA_PORT=6601 - -# Increase or decrease based on the available host memory (in bytes) -MEM_LIMIT=4073741824 - -POSTGRES_USER=root -POSTGRES_PASSWORD=infiniflow_docgpt -POSTGRES_DB=docgpt - -MINIO_USER=infiniflow -MINIO_PASSWORD=infiniflow_docgpt diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 04aabdcc121620b01d79090eb2c1476fc29e2272..b01d215aa4de5af8a8cb213d9139bc0024372d2b 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,7 +1,7 @@ version: '2.2' services: es01: - container_name: docgpt-es-01 + container_name: ragflow-es-01 image: docker.elastic.co/elasticsearch/elasticsearch:${STACK_VERSION} volumes: - esdata01:/usr/share/elasticsearch/data @@ -20,14 +20,14 @@ services: soft: -1 hard: -1 networks: - - docgpt + - ragflow restart: always kibana: depends_on: - es01 image: docker.elastic.co/kibana/kibana:${STACK_VERSION} - container_name: docgpt-kibana + container_name: ragflow-kibana volumes: - kibanadata:/usr/share/kibana/data ports: @@ -37,26 +37,39 @@ services: - ELASTICSEARCH_HOSTS=http://es01:9200 mem_limit: ${MEM_LIMIT} networks: - - docgpt + - ragflow - postgres: - image: postgres - container_name: docgpt-postgres + mysql: + image: mysql:5.7.18 + container_name: ragflow-mysql 
environment: - - POSTGRES_USER=${POSTGRES_USER} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} - - POSTGRES_DB=${POSTGRES_DB} + - MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD} + - TZ="Asia/Shanghai" + command: + --max_connections=1000 + --character-set-server=utf8mb4 + --collation-server=utf8mb4_general_ci + --default-authentication-plugin=mysql_native_password + --tls_version="TLSv1.2,TLSv1.3" + --init-file /data/application/init.sql ports: - - 5455:5432 + - ${MYSQL_PORT}:3306 volumes: - - pg_data:/var/lib/postgresql/data + - mysql_data:/var/lib/mysql + - ./init.sql:/data/application/init.sql networks: - - docgpt + - ragflow + healthcheck: + test: [ "CMD-SHELL", "curl --silent localhost:3306 >/dev/null || exit 1" ] + interval: 10s + timeout: 10s + retries: 3 restart: always + minio: image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z - container_name: docgpt-minio + container_name: ragflow-minio command: server --console-address ":9001" /data ports: - 9000:9000 @@ -67,7 +80,7 @@ services: volumes: - minio_data:/data networks: - - docgpt + - ragflow restart: always @@ -76,11 +89,11 @@ volumes: driver: local kibanadata: driver: local - pg_data: + mysql_data: driver: local minio_data: driver: local networks: - docgpt: + ragflow: driver: bridge diff --git a/docker/init.sql b/docker/init.sql new file mode 100644 index 0000000000000000000000000000000000000000..b368583dfaa0e0810232a812f816ddd80bad22ec --- /dev/null +++ b/docker/init.sql @@ -0,0 +1,2 @@ +CREATE DATABASE IF NOT EXISTS rag_flow; +USE rag_flow; \ No newline at end of file diff --git a/migration/Cargo.toml b/migration/Cargo.toml deleted file mode 100644 index df4becba81a5ceb9bbd06e76511dd306a862f713..0000000000000000000000000000000000000000 --- a/migration/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "migration" -version = "0.1.0" -edition = "2021" -publish = false - -[lib] -name = "migration" -path = "src/lib.rs" - -[dependencies] -async-std = { version = "1", features = ["attributes", "tokio1"] } -chrono = "0.4.31" - -[dependencies.sea-orm-migration] -version = "0.12.0" -features = [ - "runtime-tokio-rustls", # `ASYNC_RUNTIME` feature - "sqlx-postgres", # `DATABASE_DRIVER` feature -] diff --git a/migration/README.md b/migration/README.md deleted file mode 100644 index 83c18d5d547c2bc1f41181240a636e90fb47656b..0000000000000000000000000000000000000000 --- a/migration/README.md +++ /dev/null @@ -1,41 +0,0 @@ - # Running Migrator CLI - -- Generate a new migration file - ```sh - cargo run -- generate MIGRATION_NAME - ``` -- Apply all pending migrations - ```sh - cargo run - ``` - ```sh - cargo run -- up - ``` -- Apply first 10 pending migrations - ```sh - cargo run -- up -n 10 - ``` -- Rollback last applied migrations - ```sh - cargo run -- down - ``` -- Rollback last 10 applied migrations - ```sh - cargo run -- down -n 10 - ``` -- Drop all tables from the database, then reapply all migrations - ```sh - cargo run -- fresh - ``` -- Rollback all applied migrations, then reapply all migrations - ```sh - cargo run -- refresh - ``` -- Rollback all applied migrations - ```sh - cargo run -- reset - ``` -- Check the status of all migrations - ```sh - cargo run -- status - ``` diff --git a/migration/src/lib.rs b/migration/src/lib.rs deleted file mode 100644 index 2c605afb9423c20f8bd056fb6a5562284e5da57a..0000000000000000000000000000000000000000 --- a/migration/src/lib.rs +++ /dev/null @@ -1,12 +0,0 @@ -pub use sea_orm_migration::prelude::*; - -mod m20220101_000001_create_table; - -pub struct Migrator; - -#[async_trait::async_trait] -impl 
MigratorTrait for Migrator { - fn migrations() -> Vec<Box<dyn MigrationTrait>> { - vec![Box::new(m20220101_000001_create_table::Migration)] - } -} diff --git a/migration/src/m20220101_000001_create_table.rs b/migration/src/m20220101_000001_create_table.rs deleted file mode 100644 index 439dc4b10bac8941f36a42c9ed841ec60e8e4054..0000000000000000000000000000000000000000 --- a/migration/src/m20220101_000001_create_table.rs +++ /dev/null @@ -1,440 +0,0 @@ -use sea_orm_migration::prelude::*; -use chrono::{ FixedOffset, Utc }; - -#[allow(dead_code)] -fn now() -> chrono::DateTime<FixedOffset> { - Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) -} -#[derive(DeriveMigrationName)] -pub struct Migration; - -#[async_trait::async_trait] -impl MigrationTrait for Migration { - async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { - manager.create_table( - Table::create() - .table(UserInfo::Table) - .if_not_exists() - .col( - ColumnDef::new(UserInfo::Uid) - .big_integer() - .not_null() - .auto_increment() - .primary_key() - ) - .col(ColumnDef::new(UserInfo::Email).string().not_null()) - .col(ColumnDef::new(UserInfo::Nickname).string().not_null()) - .col(ColumnDef::new(UserInfo::AvatarBase64).string()) - .col(ColumnDef::new(UserInfo::ColorScheme).string().default("dark")) - .col(ColumnDef::new(UserInfo::ListStyle).string().default("list")) - .col(ColumnDef::new(UserInfo::Language).string().default("chinese")) - .col(ColumnDef::new(UserInfo::Password).string().not_null()) - .col( - ColumnDef::new(UserInfo::LastLoginAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - ) - .col( - ColumnDef::new(UserInfo::CreatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col( - ColumnDef::new(UserInfo::UpdatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col(ColumnDef::new(UserInfo::IsDeleted).boolean().default(false)) - .to_owned() - ).await?; - - manager.create_table( - Table::create() - .table(TagInfo::Table) - .if_not_exists() - .col( - ColumnDef::new(TagInfo::Tid) - .big_integer() - .not_null() - .auto_increment() - .primary_key() - ) - .col(ColumnDef::new(TagInfo::Uid).big_integer().not_null()) - .col(ColumnDef::new(TagInfo::TagName).string().not_null()) - .col(ColumnDef::new(TagInfo::Regx).string()) - .col(ColumnDef::new(TagInfo::Color).tiny_unsigned().default(1)) - .col(ColumnDef::new(TagInfo::Icon).tiny_unsigned().default(1)) - .col(ColumnDef::new(TagInfo::FolderId).big_integer()) - .col( - ColumnDef::new(TagInfo::CreatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col( - ColumnDef::new(TagInfo::UpdatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col(ColumnDef::new(TagInfo::IsDeleted).boolean().default(false)) - .to_owned() - ).await?; - - manager.create_table( - Table::create() - .table(Tag2Doc::Table) - .if_not_exists() - .col( - ColumnDef::new(Tag2Doc::Id) - .big_integer() - .not_null() - .auto_increment() - .primary_key() - ) - .col(ColumnDef::new(Tag2Doc::TagId).big_integer()) - .col(ColumnDef::new(Tag2Doc::Did).big_integer()) - .to_owned() - ).await?; - - manager.create_table( - Table::create() - .table(Kb2Doc::Table) - .if_not_exists() - .col( - ColumnDef::new(Kb2Doc::Id) - .big_integer() - .not_null() - .auto_increment() - .primary_key() - ) - .col(ColumnDef::new(Kb2Doc::KbId).big_integer()) - .col(ColumnDef::new(Kb2Doc::Did).big_integer()) - .col(ColumnDef::new(Kb2Doc::KbProgress).float().default(0)) -
.col(ColumnDef::new(Kb2Doc::KbProgressMsg).string().default("")) - .col( - ColumnDef::new(Kb2Doc::UpdatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col(ColumnDef::new(Kb2Doc::IsDeleted).boolean().default(false)) - .to_owned() - ).await?; - - manager.create_table( - Table::create() - .table(Dialog2Kb::Table) - .if_not_exists() - .col( - ColumnDef::new(Dialog2Kb::Id) - .big_integer() - .not_null() - .auto_increment() - .primary_key() - ) - .col(ColumnDef::new(Dialog2Kb::DialogId).big_integer()) - .col(ColumnDef::new(Dialog2Kb::KbId).big_integer()) - .to_owned() - ).await?; - - manager.create_table( - Table::create() - .table(Doc2Doc::Table) - .if_not_exists() - .col( - ColumnDef::new(Doc2Doc::Id) - .big_integer() - .not_null() - .auto_increment() - .primary_key() - ) - .col(ColumnDef::new(Doc2Doc::ParentId).big_integer()) - .col(ColumnDef::new(Doc2Doc::Did).big_integer()) - .to_owned() - ).await?; - - manager.create_table( - Table::create() - .table(KbInfo::Table) - .if_not_exists() - .col( - ColumnDef::new(KbInfo::KbId) - .big_integer() - .auto_increment() - .not_null() - .primary_key() - ) - .col(ColumnDef::new(KbInfo::Uid).big_integer().not_null()) - .col(ColumnDef::new(KbInfo::KbName).string().not_null()) - .col(ColumnDef::new(KbInfo::Icon).tiny_unsigned().default(1)) - .col( - ColumnDef::new(KbInfo::CreatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col( - ColumnDef::new(KbInfo::UpdatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col(ColumnDef::new(KbInfo::IsDeleted).boolean().default(false)) - .to_owned() - ).await?; - - manager.create_table( - Table::create() - .table(DocInfo::Table) - .if_not_exists() - .col( - ColumnDef::new(DocInfo::Did) - .big_integer() - .not_null() - .auto_increment() - .primary_key() - ) - .col(ColumnDef::new(DocInfo::Uid).big_integer().not_null()) - .col(ColumnDef::new(DocInfo::DocName).string().not_null()) - .col(ColumnDef::new(DocInfo::Location).string().not_null()) - .col(ColumnDef::new(DocInfo::Size).big_integer().not_null()) - .col(ColumnDef::new(DocInfo::Type).string().not_null()) - .col(ColumnDef::new(DocInfo::ThumbnailBase64).string().default("")) - .comment("doc type|folder") - .col( - ColumnDef::new(DocInfo::CreatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col( - ColumnDef::new(DocInfo::UpdatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col(ColumnDef::new(DocInfo::IsDeleted).boolean().default(false)) - .to_owned() - ).await?; - - manager.create_table( - Table::create() - .table(DialogInfo::Table) - .if_not_exists() - .col( - ColumnDef::new(DialogInfo::DialogId) - .big_integer() - .not_null() - .auto_increment() - .primary_key() - ) - .col(ColumnDef::new(DialogInfo::Uid).big_integer().not_null()) - .col(ColumnDef::new(DialogInfo::KbId).big_integer().not_null()) - .col(ColumnDef::new(DialogInfo::DialogName).string().not_null()) - .col(ColumnDef::new(DialogInfo::History).string().comment("json")) - .col( - ColumnDef::new(DialogInfo::CreatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col( - ColumnDef::new(DialogInfo::UpdatedAt) - .timestamp_with_time_zone() - .default(Expr::current_timestamp()) - .not_null() - ) - .col(ColumnDef::new(DialogInfo::IsDeleted).boolean().default(false)) - .to_owned() - ).await?; - - let root_insert = Query::insert() - 
.into_table(UserInfo::Table) - .columns([UserInfo::Email, UserInfo::Nickname, UserInfo::Password]) - .values_panic(["kai.hu@infiniflow.org".into(), "root".into(), "123456".into()]) - .to_owned(); - - let doc_insert = Query::insert() - .into_table(DocInfo::Table) - .columns([ - DocInfo::Uid, - DocInfo::DocName, - DocInfo::Size, - DocInfo::Type, - DocInfo::Location, - ]) - .values_panic([(1).into(), "/".into(), (0).into(), "folder".into(), "".into()]) - .to_owned(); - - let tag_insert = Query::insert() - .into_table(TagInfo::Table) - .columns([TagInfo::Uid, TagInfo::TagName, TagInfo::Regx, TagInfo::Color, TagInfo::Icon]) - .values_panic([ - (1).into(), - "Video".into(), - ".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)".into(), - (1).into(), - (1).into(), - ]) - .values_panic([ - (1).into(), - "Picture".into(), - ".*\\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)".into(), - (2).into(), - (2).into(), - ]) - .values_panic([ - (1).into(), - "Music".into(), - ".*\\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)".into(), - (3).into(), - (3).into(), - ]) - .values_panic([ - (1).into(), - "Document".into(), - ".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)".into(), - (3).into(), - (3).into(), - ]) - .to_owned(); - - manager.exec_stmt(root_insert).await?; - manager.exec_stmt(doc_insert).await?; - manager.exec_stmt(tag_insert).await?; - Ok(()) - } - - async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { - manager.drop_table(Table::drop().table(UserInfo::Table).to_owned()).await?; - - manager.drop_table(Table::drop().table(TagInfo::Table).to_owned()).await?; - - manager.drop_table(Table::drop().table(Tag2Doc::Table).to_owned()).await?; - - manager.drop_table(Table::drop().table(Kb2Doc::Table).to_owned()).await?; - - manager.drop_table(Table::drop().table(Dialog2Kb::Table).to_owned()).await?; - - manager.drop_table(Table::drop().table(Doc2Doc::Table).to_owned()).await?; - - manager.drop_table(Table::drop().table(KbInfo::Table).to_owned()).await?; - - manager.drop_table(Table::drop().table(DocInfo::Table).to_owned()).await?; - - manager.drop_table(Table::drop().table(DialogInfo::Table).to_owned()).await?; - - Ok(()) - } -} - -#[derive(DeriveIden)] -enum UserInfo { - Table, - Uid, - Email, - Nickname, - AvatarBase64, - ColorScheme, - ListStyle, - Language, - Password, - LastLoginAt, - CreatedAt, - UpdatedAt, - IsDeleted, -} - -#[derive(DeriveIden)] -enum TagInfo { - Table, - Tid, - Uid, - TagName, - Regx, - Color, - Icon, - FolderId, - CreatedAt, - UpdatedAt, - IsDeleted, -} - -#[derive(DeriveIden)] -enum Tag2Doc { - Table, - Id, - TagId, - Did, -} - -#[derive(DeriveIden)] -enum Kb2Doc { - Table, - Id, - KbId, - Did, - KbProgress, - KbProgressMsg, - UpdatedAt, - IsDeleted, -} - -#[derive(DeriveIden)] -enum Dialog2Kb { - Table, - Id, - DialogId, - KbId, -} - -#[derive(DeriveIden)] -enum Doc2Doc { - Table, - Id, - ParentId, - Did, -} - -#[derive(DeriveIden)] -enum KbInfo { - Table, - KbId, - Uid, - KbName, - Icon, - CreatedAt, - UpdatedAt, - IsDeleted, -} - -#[derive(DeriveIden)] -enum DocInfo { - Table, - Did, - Uid, - DocName, - Location, - Size, - Type, - ThumbnailBase64, - CreatedAt, - UpdatedAt, - IsDeleted, -} - -#[derive(DeriveIden)] -enum DialogInfo { - Table, - Uid, - KbId, - DialogId, - DialogName, - History, - CreatedAt, - UpdatedAt, - IsDeleted, -} diff --git a/migration/src/main.rs b/migration/src/main.rs deleted file mode 100644 index 
c6b6e48dbc06b53a1e439cd36362080052e26b43..0000000000000000000000000000000000000000 --- a/migration/src/main.rs +++ /dev/null @@ -1,6 +0,0 @@ -use sea_orm_migration::prelude::*; - -#[async_std::main] -async fn main() { - cli::run_cli(migration::Migrator).await; -} diff --git a/python/Dockerfile b/python/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..f8333fdde124b2bcc8963d469250511c3e81be2f --- /dev/null +++ b/python/Dockerfile @@ -0,0 +1,29 @@ +FROM ubuntu:22.04 as base + +RUN apt-get update + +ENV TZ="Asia/Taipei" +RUN apt-get install -yq \ + build-essential \ + curl \ + libncursesw5-dev \ + libssl-dev \ + libsqlite3-dev \ + libgdbm-dev \ + libc6-dev \ + libbz2-dev \ + software-properties-common \ + python3.11 python3.11-dev python3-pip + +RUN apt-get install -yq git +RUN pip3 config set global.index-url https://mirror.baidu.com/pypi/simple +RUN pip3 config set global.trusted-host mirror.baidu.com +RUN pip3 install --upgrade pip +RUN pip3 install torch==2.0.1 +RUN pip3 install torch-model-archiver==0.8.2 +RUN pip3 install torchvision==0.15.2 +COPY requirements.txt . + +WORKDIR /docgpt +ENV PYTHONPATH=/docgpt/ + diff --git a/python/README.md b/python/README.md deleted file mode 100644 index 4f351eb9a62b884dedbf3e45ab8fea44bc0c8330..0000000000000000000000000000000000000000 --- a/python/README.md +++ /dev/null @@ -1,22 +0,0 @@ - -```shell - -docker pull postgres - -LOCAL_POSTGRES_DATA=./postgres-data - -docker run - --name docass-postgres - -p 5455:5432 - -v $LOCAL_POSTGRES_DATA:/var/lib/postgresql/data - -e POSTGRES_USER=root - -e POSTGRES_PASSWORD=infiniflow_docass - -e POSTGRES_DB=docass - -d - postgres - -docker network create elastic -docker pull elasticsearch:8.11.3; -docker pull docker.elastic.co/kibana/kibana:8.11.3 - -``` diff --git a/python/nlp/__init__.py b/python/ToPDF.pdf similarity index 100% rename from python/nlp/__init__.py rename to python/ToPDF.pdf diff --git a/python/] b/python/] new file mode 100644 index 0000000000000000000000000000000000000000..4f413e02b27810bb159d182b0430ee087d920266 --- /dev/null +++ b/python/] @@ -0,0 +1,63 @@ +from abc import ABC +from openai import OpenAI +import os +import base64 +from io import BytesIO + +class Base(ABC): + def describe(self, image, max_tokens=300): + raise NotImplementedError("Please implement encode method!") + + +class GptV4(Base): + def __init__(self): + import openapi + openapi.api_key = os.environ["OPENAPI_KEY"] + self.client = OpenAI() + + def describe(self, image, max_tokens=300): + buffered = BytesIO() + try: + image.save(buffered, format="JPEG") + except Exception as e: + image.save(buffered, format="PNG") + b64 = base64.b64encode(buffered.getvalue()).decode("utf-8") + + res = self.client.chat.completions.create( + model="gpt-4-vision-preview", + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。", + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{b64}" + }, + }, + ], + } + ], + max_tokens=max_tokens, + ) + return res.choices[0].message.content.strip() + + +class QWen(Base): + def chat(self, system, history, gen_conf): + from http import HTTPStatus + from dashscope import Generation + from dashscope.api_entities.dashscope_response import Role + # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY + response = Generation.call( + Generation.Models.qwen_turbo, + messages=messages, + result_format='message' + ) + if response.status_code == HTTPStatus.OK: + return 
response.output.choices[0]['message']['content'] + return response.message diff --git a/python/conf/logging.json b/python/conf/logging.json deleted file mode 100755 index f35a4d766d9b6ddd646248e2c09cb295c0f42784..0000000000000000000000000000000000000000 --- a/python/conf/logging.json +++ /dev/null @@ -1,41 +0,0 @@ -{ - "version":1, - "disable_existing_loggers":false, - "formatters":{ - "simple":{ - "format":"%(asctime)s - %(name)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s" - } - }, - "handlers":{ - "console":{ - "class":"logging.StreamHandler", - "level":"DEBUG", - "formatter":"simple", - "stream":"ext://sys.stdout" - }, - "info_file_handler":{ - "class":"logging.handlers.TimedRotatingFileHandler", - "level":"INFO", - "formatter":"simple", - "filename":"log/info.log", - "when": "MIDNIGHT", - "interval":1, - "backupCount":30, - "encoding":"utf8" - }, - "error_file_handler":{ - "class":"logging.handlers.TimedRotatingFileHandler", - "level":"ERROR", - "formatter":"simple", - "filename":"log/errors.log", - "when": "MIDNIGHT", - "interval":1, - "backupCount":30, - "encoding":"utf8" - } - }, - "root":{ - "level":"DEBUG", - "handlers":["console","info_file_handler","error_file_handler"] - } -} diff --git a/python/conf/sys.cnf b/python/conf/sys.cnf deleted file mode 100755 index 50a47a43051b892f7f1f32a182541861ab6829c2..0000000000000000000000000000000000000000 --- a/python/conf/sys.cnf +++ /dev/null @@ -1,9 +0,0 @@ -[infiniflow] -es=http://es01:9200 -postgres_user=root -postgres_password=infiniflow_docgpt -postgres_host=postgres -postgres_port=5432 -minio_host=minio:9000 -minio_user=infiniflow -minio_password=infiniflow_docgpt diff --git a/python/llm/__init__.py b/python/llm/__init__.py deleted file mode 100644 index d7b4ef4401535d7cc357816e4e3ea6063fb80f93..0000000000000000000000000000000000000000 --- a/python/llm/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -import os -from .embedding_model import * -from .chat_model import * -from .cv_model import * - -EmbeddingModel = None -ChatModel = None -CvModel = None - - -if os.environ.get("OPENAI_API_KEY"): - EmbeddingModel = GptEmbed() - ChatModel = GptTurbo() - CvModel = GptV4() - -elif os.environ.get("DASHSCOPE_API_KEY"): - EmbeddingModel = QWenEmbd() - ChatModel = QWenChat() - CvModel = QWenCV() -else: - EmbeddingModel = HuEmbedding() diff --git a/python/llm/embedding_model.py b/python/llm/embedding_model.py deleted file mode 100644 index f39aa675c05a1187b9f9ccc12899411f2140455d..0000000000000000000000000000000000000000 --- a/python/llm/embedding_model.py +++ /dev/null @@ -1,61 +0,0 @@ -from abc import ABC -from openai import OpenAI -from FlagEmbedding import FlagModel -import torch -import os -import numpy as np - - -class Base(ABC): - def encode(self, texts: list, batch_size=32): - raise NotImplementedError("Please implement encode method!") - - -class HuEmbedding(Base): - def __init__(self): - """ - If you have trouble downloading HuggingFace models, -_^ this might help!! 
- - For Linux: - export HF_ENDPOINT=https://hf-mirror.com - - For Windows: - Good luck - ^_- - - """ - self.model = FlagModel("BAAI/bge-large-zh-v1.5", - query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", - use_fp16=torch.cuda.is_available()) - - def encode(self, texts: list, batch_size=32): - res = [] - for i in range(0, len(texts), batch_size): - res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) - return np.array(res) - - -class GptEmbed(Base): - def __init__(self): - self.client = OpenAI(api_key=os.envirement["OPENAI_API_KEY"]) - - def encode(self, texts: list, batch_size=32): - res = self.client.embeddings.create(input=texts, - model="text-embedding-ada-002") - return [d["embedding"] for d in res["data"]] - - -class QWenEmbd(Base): - def encode(self, texts: list, batch_size=32, text_type="document"): - # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY - import dashscope - from http import HTTPStatus - res = [] - for txt in texts: - resp = dashscope.TextEmbedding.call( - model=dashscope.TextEmbedding.Models.text_embedding_v2, - input=txt[:2048], - text_type=text_type - ) - res.append(resp["output"]["embeddings"][0]["embedding"]) - return res diff --git a/python/output/ToPDF.pdf b/python/output/ToPDF.pdf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/python/parser/excel_parser.py b/python/parser/excel_parser.py deleted file mode 100644 index e44611493ed639abb76a778b5a78199acfa09819..0000000000000000000000000000000000000000 --- a/python/parser/excel_parser.py +++ /dev/null @@ -1,25 +0,0 @@ -from openpyxl import load_workbook -import sys -from io import BytesIO - - -class HuExcelParser: - def __call__(self, fnm): - if isinstance(fnm, str): - wb = load_workbook(fnm) - else: - wb = load_workbook(BytesIO(fnm)) - res = [] - for sheetname in wb.sheetnames: - ws = wb[sheetname] - lines = [] - for r in ws.rows: - lines.append( - "\t".join([str(c.value) if c.value is not None else "" for c in r])) - res.append(f"《{sheetname}》\n" + "\n".join(lines)) - return res - - -if __name__ == "__main__": - psr = HuExcelParser() - psr(sys.argv[1]) diff --git a/python/requirements.txt b/python/requirements.txt deleted file mode 100644 index c0b1fb43186ec39c2baec281a80e16e6bbf75c6a..0000000000000000000000000000000000000000 --- a/python/requirements.txt +++ /dev/null @@ -1,194 +0,0 @@ -accelerate==0.24.1 -addict==2.4.0 -aiobotocore==2.7.0 -aiofiles==23.2.1 -aiohttp==3.8.6 -aioitertools==0.11.0 -aiosignal==1.3.1 -aliyun-python-sdk-core==2.14.0 -aliyun-python-sdk-kms==2.16.2 -altair==5.1.2 -anyio==3.7.1 -astor==0.8.1 -async-timeout==4.0.3 -attrdict==2.0.1 -attrs==23.1.0 -Babel==2.13.1 -bce-python-sdk==0.8.92 -beautifulsoup4==4.12.2 -bitsandbytes==0.41.1 -blinker==1.7.0 -botocore==1.31.64 -cachetools==5.3.2 -certifi==2023.7.22 -cffi==1.16.0 -charset-normalizer==3.3.2 -click==8.1.7 -cloudpickle==3.0.0 -contourpy==1.2.0 -crcmod==1.7 -cryptography==41.0.5 -cssselect==1.2.0 -cssutils==2.9.0 -cycler==0.12.1 -Cython==3.0.5 -datasets==2.13.0 -datrie==0.8.2 -decorator==5.1.1 -defusedxml==0.7.1 -dill==0.3.6 -einops==0.7.0 -elastic-transport==8.10.0 -elasticsearch==8.10.1 -elasticsearch-dsl==8.9.0 -et-xmlfile==1.1.0 -fastapi==0.104.1 -ffmpy==0.3.1 -filelock==3.13.1 -fire==0.5.0 -FlagEmbedding==1.1.5 -Flask==3.0.0 -flask-babel==4.0.0 -fonttools==4.44.0 -frozenlist==1.4.0 -fsspec==2023.10.0 -future==0.18.3 -gast==0.5.4 --e 
-git+https://github.com/ggerganov/llama.cpp.git@5f6e0c0dff1e7a89331e6b25eca9a9fd71324069#egg=gguf&subdirectory=gguf-py -gradio==3.50.2 -gradio_client==0.6.1 -greenlet==3.0.1 -h11==0.14.0 -hanziconv==0.3.2 -httpcore==1.0.1 -httpx==0.25.1 -huggingface-hub==0.17.3 -idna==3.4 -imageio==2.31.6 -imgaug==0.4.0 -importlib-metadata==6.8.0 -importlib-resources==6.1.0 -install==1.3.5 -itsdangerous==2.1.2 -Jinja2==3.1.2 -jmespath==0.10.0 -joblib==1.3.2 -jsonschema==4.19.2 -jsonschema-specifications==2023.7.1 -kiwisolver==1.4.5 -lazy_loader==0.3 -lmdb==1.4.1 -lxml==4.9.3 -MarkupSafe==2.1.3 -matplotlib==3.8.1 -modelscope==1.9.4 -mpmath==1.3.0 -multidict==6.0.4 -multiprocess==0.70.14 -networkx==3.2.1 -nltk==3.8.1 -numpy==1.24.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==8.9.2.26 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.18.1 -nvidia-nvjitlink-cu12==12.3.52 -nvidia-nvtx-cu12==12.1.105 -opencv-contrib-python==4.6.0.66 -opencv-python==4.6.0.66 -openpyxl==3.1.2 -opt-einsum==3.3.0 -orjson==3.9.10 -oss2==2.18.3 -packaging==23.2 -paddleocr==2.7.0.3 -paddlepaddle-gpu==2.5.2.post120 -pandas==2.1.2 -pdf2docx==0.5.5 -pdfminer.six==20221105 -pdfplumber==0.10.3 -Pillow==10.0.1 -platformdirs==3.11.0 -premailer==3.10.0 -protobuf==4.25.0 -psutil==5.9.6 -pyarrow==14.0.0 -pyclipper==1.3.0.post5 -pycocotools==2.0.7 -pycparser==2.21 -pycryptodome==3.19.0 -pydantic==1.10.13 -pydub==0.25.1 -PyMuPDF==1.20.2 -pyparsing==3.1.1 -pypdfium2==4.23.1 -python-dateutil==2.8.2 -python-docx==1.1.0 -python-multipart==0.0.6 -pytz==2023.3.post1 -PyYAML==6.0.1 -rapidfuzz==3.5.2 -rarfile==4.1 -referencing==0.30.2 -regex==2023.10.3 -requests==2.31.0 -rpds-py==0.12.0 -s3fs==2023.10.0 -safetensors==0.4.0 -scikit-image==0.22.0 -scikit-learn==1.3.2 -scipy==1.11.3 -semantic-version==2.10.0 -sentence-transformers==2.2.2 -sentencepiece==0.1.98 -shapely==2.0.2 -simplejson==3.19.2 -six==1.16.0 -sniffio==1.3.0 -sortedcontainers==2.4.0 -soupsieve==2.5 -SQLAlchemy==2.0.23 -starlette==0.27.0 -sympy==1.12 -tabulate==0.9.0 -tblib==3.0.0 -termcolor==2.3.0 -threadpoolctl==3.2.0 -tifffile==2023.9.26 -tiktoken==0.5.1 -timm==0.9.10 -tokenizers==0.13.3 -tomli==2.0.1 -toolz==0.12.0 -torch==2.1.0 -torchaudio==2.1.0 -torchvision==0.16.0 -tornado==6.3.3 -tqdm==4.66.1 -transformers==4.33.0 -transformers-stream-generator==0.0.4 -triton==2.1.0 -typing_extensions==4.8.0 -tzdata==2023.3 -urllib3==2.0.7 -uvicorn==0.24.0 -uvloop==0.19.0 -visualdl==2.5.3 -websockets==11.0.3 -Werkzeug==3.0.1 -wrapt==1.15.0 -xgboost==2.0.1 -xinference==0.6.0 -xorbits==0.7.0 -xoscar==0.1.3 -xxhash==3.4.1 -yapf==0.40.2 -yarl==1.9.2 -zipp==3.17.0 diff --git a/python/res/1-0.tm b/python/res/1-0.tm new file mode 100644 index 0000000000000000000000000000000000000000..ae009308ee3203a2b57f1060470e075e2b0751b1 --- /dev/null +++ b/python/res/1-0.tm @@ -0,0 +1,8 @@ +2023-12-20 11:44:08.791336+00:00 +2023-12-20 11:44:08.853249+00:00 +2023-12-20 11:44:08.909933+00:00 +2023-12-21 00:47:09.996757+00:00 +2023-12-20 11:44:08.965855+00:00 +2023-12-20 11:44:09.011682+00:00 +2023-12-21 00:47:10.063326+00:00 +2023-12-20 11:44:09.069486+00:00 diff --git a/python/res/thumbnail-1-0.tm b/python/res/thumbnail-1-0.tm new file mode 100644 index 0000000000000000000000000000000000000000..dbadd02326aa68820d7320ff20ed07e3cb3f955d --- /dev/null +++ b/python/res/thumbnail-1-0.tm @@ -0,0 +1,3 @@ +2023-12-27 
08:21:49.309802+00:00 +2023-12-27 08:37:22.407772+00:00 +2023-12-27 08:59:18.845627+00:00 diff --git a/python/svr/add_thumbnail2file.py b/python/svr/add_thumbnail2file.py deleted file mode 100644 index e4558ca9a2f601b4cced9fd473815e6653a817e8..0000000000000000000000000000000000000000 --- a/python/svr/add_thumbnail2file.py +++ /dev/null @@ -1,118 +0,0 @@ -import sys, datetime, random, re, cv2 -from os.path import dirname, realpath -sys.path.append(dirname(realpath(__file__)) + "/../") -from util.db_conn import Postgres -from util.minio_conn import HuMinio -from util import findMaxDt -import base64 -from io import BytesIO -import pandas as pd -from PIL import Image -import pdfplumber - - -PG = Postgres("infiniflow", "docgpt") -MINIO = HuMinio("infiniflow") -def set_thumbnail(did, base64): - sql = f""" - update doc_info set thumbnail_base64='{base64}' - where - did={did} - """ - PG.update(sql) - - -def collect(comm, mod, tm): - sql = f""" - select - did, uid, doc_name, location, updated_at - from doc_info - where - updated_at >= '{tm}' - and MOD(did, {comm}) = {mod} - and is_deleted=false - and type <> 'folder' - and thumbnail_base64='' - order by updated_at asc - limit 10 - """ - docs = PG.select(sql) - if len(docs) == 0:return pd.DataFrame() - - mtm = str(docs["updated_at"].max())[:19] - print("TOTAL:", len(docs), "To: ", mtm) - return docs - - -def build(row): - if not re.search(r"\.(pdf|jpg|jpeg|png|gif|svg|apng|icon|ico|webp|mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$", - row["doc_name"].lower().strip()): - set_thumbnail(row["did"], "_") - return - - def thumbnail(img, SIZE=128): - w,h = img.size - p = SIZE/max(w, h) - w, h = int(w*p), int(h*p) - img.thumbnail((w, h)) - buffered = BytesIO() - try: - img.save(buffered, format="JPEG") - except Exception as e: - try: - img.save(buffered, format="PNG") - except Exception as ee: - pass - return base64.b64encode(buffered.getvalue()).decode("utf-8") - - - iobytes = BytesIO(MINIO.get("%s-upload"%str(row["uid"]), row["location"])) - if re.search(r"\.pdf$", row["doc_name"].lower().strip()): - pdf = pdfplumber.open(iobytes) - img = pdf.pages[0].to_image().annotated - set_thumbnail(row["did"], thumbnail(img)) - - if re.search(r"\.(jpg|jpeg|png|gif|svg|apng|webp|icon|ico)$", row["doc_name"].lower().strip()): - img = Image.open(iobytes) - set_thumbnail(row["did"], thumbnail(img)) - - if re.search(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$", row["doc_name"].lower().strip()): - url = MINIO.get_presigned_url("%s-upload"%str(row["uid"]), - row["location"], - expires=datetime.timedelta(seconds=60) - ) - cap = cv2.VideoCapture(url) - succ = cap.isOpened() - i = random.randint(1, 11) - while succ: - ret, frame = cap.read() - if not ret: break - if i > 0: - i -= 1 - continue - img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - print(img.size) - set_thumbnail(row["did"], thumbnail(img)) - cap.release() - cv2.destroyAllWindows() - - -def main(comm, mod): - global model - tm_fnm = f"res/thumbnail-{comm}-{mod}.tm" - tm = findMaxDt(tm_fnm) - rows = collect(comm, mod, tm) - if len(rows) == 0:return - - tmf = open(tm_fnm, "a+") - for _, r in rows.iterrows(): - build(r) - tmf.write(str(r["updated_at"]) + "\n") - tmf.close() - - -if __name__ == "__main__": - from mpi4py import MPI - comm = MPI.COMM_WORLD - main(comm.Get_size(), comm.Get_rank()) - diff --git a/python/svr/dialog_svr.py b/python/svr/dialog_svr.py deleted file mode 100755 index 9e93dccffc8893d8c352f1642e4dd47b3eecb0cc..0000000000000000000000000000000000000000 --- a/python/svr/dialog_svr.py +++ /dev/null 
@@ -1,165 +0,0 @@ -#-*- coding:utf-8 -*- -import sys, os, re,inspect,json,traceback,logging,argparse, copy -sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../") -from tornado.web import RequestHandler,Application -from tornado.ioloop import IOLoop -from tornado.httpserver import HTTPServer -from tornado.options import define,options -from util import es_conn, setup_logging -from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity -from nlp import huqie -from nlp import query as Query -from nlp import search -from llm import HuEmbedding, GptTurbo -import numpy as np -from io import BytesIO -from util import config -from timeit import default_timer as timer -from collections import OrderedDict -from llm import ChatModel, EmbeddingModel - -SE = None -CFIELD="content_ltks" -EMBEDDING = EmbeddingModel -LLM = ChatModel - -def get_QA_pairs(hists): - pa = [] - for h in hists: - for k in ["user", "assistant"]: - if h.get(k): - pa.append({ - "content": h[k], - "role": k, - }) - - for p in pa[:-1]: assert len(p) == 2, p - return pa - - - -def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"): - max_len //= len(top_i) - # add instruction to prompt - instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i] - if len(instructions)>2: - # Said that LLM is sensitive to the first and the last one, so - # rearrange the order of references - instructions.append(copy.deepcopy(instructions[1])) - instructions.pop(1) - - def token_num(txt): - c = 0 - for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt): - if re.match(r"[a-zA-Z-]+$", tk): - c += 1 - continue - c += len(tk) - return c - - _inst = "" - for ins in instructions: - if token_num(_inst) > 4096: - _inst += "\n知识库:" + instructions[-1][:max_len] - break - _inst += "\n知识库:" + ins[:max_len] - return _inst - - -def prompt_and_answer(history, inst): - hist = get_QA_pairs(history) - chks = [] - for s in re.split(r"[::;;。\n\r]+", inst): - if s: chks.append(s) - chks = len(set(chks))/(0.1+len(chks)) - print("Duplication portion:", chks) - - system = """ -你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。 -以下是知识库: -%s -以上是知识库。 -"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst) - - print("【PROMPT】:", system) - start = timer() - response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512}) - print("GENERATE: ", timer()-start) - print("===>>", response) - return response - - -class Handler(RequestHandler): - def post(self): - global SE,MUST_TK_NUM - param = json.loads(self.request.body.decode('utf-8')) - try: - question = param.get("history",[{"user": "Hi!"}])[-1]["user"] - res = SE.search({ - "question": question, - "kb_ids": param.get("kb_ids", []), - "size": param.get("topn", 15)}, - search.index_name(param["uid"]) - ) - - sim = SE.rerank(res, question) - rk_idx = np.argsort(sim*-1) - topidx = [i for i in rk_idx if sim[i] >= aram.get("similarity", 0.5)][:param.get("topn",12)] - inst = get_instruction(res, topidx) - - ans, topidx = prompt_and_answer(param["history"], inst) - ans = SE.insert_citations(ans, topidx, res) - - refer = OrderedDict() - docnms = {} - for i in rk_idx: - did = res.field[res.ids[i]]["doc_id"] - if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"] - if did not in refer: refer[did] = [] - refer[did].append({ - "chunk_id": res.ids[i], - "content": res.field[res.ids[i]]["content_ltks"], - "image": "" - }) - - print("::::::::::::::", ans) - 
self.write(json.dumps({ - "code":0, - "msg":"success", - "data":{ - "uid": param["uid"], - "dialog_id": param["dialog_id"], - "assistant": ans, - "refer": [{ - "did": did, - "doc_name": docnms[did], - "chunks": chunks - } for did, chunks in refer.items()] - } - })) - logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False)) - - except Exception as e: - logging.error("Request 500: "+str(e)) - self.write(json.dumps({ - "code":500, - "msg":str(e), - "data":{} - })) - print(traceback.format_exc()) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--port", default=4455, type=int, help="Port used for service") - ARGS = parser.parse_args() - - SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING) - - app = Application([(r'/v1/chat/completions', Handler)],debug=False) - http_server = HTTPServer(app) - http_server.bind(ARGS.port) - http_server.start(3) - - IOLoop.current().start() - diff --git a/python/svr/parse_user_docs.py b/python/svr/parse_user_docs.py deleted file mode 100644 index b129c7b010cef1e55014cd2a4c4c7533953dbf16..0000000000000000000000000000000000000000 --- a/python/svr/parse_user_docs.py +++ /dev/null @@ -1,258 +0,0 @@ -import json, os, sys, hashlib, copy, time, random, re -from os.path import dirname, realpath -sys.path.append(dirname(realpath(__file__)) + "/../") -from util.es_conn import HuEs -from util.db_conn import Postgres -from util.minio_conn import HuMinio -from util import rmSpace, findMaxDt -from FlagEmbedding import FlagModel -from nlp import huchunk, huqie, search -from io import BytesIO -import pandas as pd -from elasticsearch_dsl import Q -from PIL import Image -from parser import ( - PdfParser, - DocxParser, - ExcelParser -) -from nlp.huchunk import ( - PdfChunker, - DocxChunker, - ExcelChunker, - PptChunker, - TextChunker -) - -ES = HuEs("infiniflow") -BATCH_SIZE = 64 -PG = Postgres("infiniflow", "docgpt") -MINIO = HuMinio("infiniflow") - -PDF = PdfChunker(PdfParser()) -DOC = DocxChunker(DocxParser()) -EXC = ExcelChunker(ExcelParser()) -PPT = PptChunker() - -def chuck_doc(name, binary): - suff = os.path.split(name)[-1].lower().split(".")[-1] - if suff.find("pdf") >= 0: return PDF(binary) - if suff.find("doc") >= 0: return DOC(binary) - if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary) - if suff.find("ppt") >= 0: return PPT(binary) - if os.envirement.get("PARSE_IMAGE") \ - and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$", - name.lower()): - from llm import CvModel - txt = CvModel.describe(binary) - field = TextChunker.Fields() - field.text_chunks = [(txt, binary)] - field.table_chunks = [] - - - return TextChunker()(binary) - - -def collect(comm, mod, tm): - sql = f""" - select - id as kb2doc_id, - kb_id, - did, - updated_at, - is_deleted - from kb2_doc - where - updated_at >= '{tm}' - and kb_progress = 0 - and MOD(did, {comm}) = {mod} - order by updated_at asc - limit 1000 - """ - kb2doc = PG.select(sql) - if len(kb2doc) == 0:return pd.DataFrame() - - sql = """ - select - did, - uid, - doc_name, - location, - size - from doc_info - where - did in (%s) - """%",".join([str(i) for i in kb2doc["did"].unique()]) - docs = PG.select(sql) - docs = docs.fillna("") - docs = docs.join(kb2doc.set_index("did"), on="did", how="left") - - mtm = str(docs["updated_at"].max())[:19] - print("TOTAL:", len(docs), "To: ", mtm) - return docs - - -def set_progress(kb2doc_id, prog, msg="Processing..."): - sql = f""" - update kb2_doc set 
kb_progress={prog}, kb_progress_msg='{msg}' - where - id={kb2doc_id} - """ - PG.update(sql) - - -def build(row): - if row["size"] > 256000000: - set_progress(row["kb2doc_id"], -1, "File size exceeds( <= 256Mb )") - return [] - res = ES.search(Q("term", doc_id=row["did"])) - if ES.getTotal(res) > 0: - ES.updateScriptByQuery(Q("term", doc_id=row["did"]), - scripts=""" - if(!ctx._source.kb_id.contains('%s')) - ctx._source.kb_id.add('%s'); - """%(str(row["kb_id"]), str(row["kb_id"])), - idxnm = search.index_name(row["uid"]) - ) - set_progress(row["kb2doc_id"], 1, "Done") - return [] - - random.seed(time.time()) - set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!") - try: - obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"])) - except Exception as e: - if re.search("(No such file|not found)", str(e)): - set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"]) - else: - set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", "")) - return [] - - if not obj.text_chunks and not obj.table_chunks: - set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.") - return [] - - set_progress(row["kb2doc_id"], random.randint(20, 60)/100., "Finished slicing files. Start to embedding the content.") - - doc = { - "doc_id": row["did"], - "kb_id": [str(row["kb_id"])], - "docnm_kwd": os.path.split(row["location"])[-1], - "title_tks": huqie.qie(os.path.split(row["location"])[-1]), - "updated_at": str(row["updated_at"]).replace("T", " ")[:19] - } - doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) - output_buffer = BytesIO() - docs = [] - md5 = hashlib.md5() - for txt, img in obj.text_chunks: - d = copy.deepcopy(doc) - md5.update((txt + str(d["doc_id"])).encode("utf-8")) - d["_id"] = md5.hexdigest() - d["content_ltks"] = huqie.qie(txt) - d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) - if not img: - docs.append(d) - continue - - if isinstance(img, Image): img.save(output_buffer, format='JPEG') - else: output_buffer = BytesIO(img) - - MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"], - output_buffer.getvalue()) - d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"]) - docs.append(d) - - for arr, img in obj.table_chunks: - for i, txt in enumerate(arr): - d = copy.deepcopy(doc) - d["content_ltks"] = huqie.qie(txt) - md5.update((txt + str(d["doc_id"])).encode("utf-8")) - d["_id"] = md5.hexdigest() - if not img: - docs.append(d) - continue - img.save(output_buffer, format='JPEG') - MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"], - output_buffer.getvalue()) - d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"]) - docs.append(d) - set_progress(row["kb2doc_id"], random.randint(60, 70)/100., "Continue embedding the content.") - - return docs - - -def init_kb(row): - idxnm = search.index_name(row["uid"]) - if ES.indexExist(idxnm): return - return ES.createIdx(idxnm, json.load(open("conf/mapping.json", "r"))) - - -model = None -def embedding(docs): - global model - tts = model.encode([rmSpace(d["title_tks"]) for d in docs]) - cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs]) - vects = 0.1 * tts + 0.9 * cnts - assert len(vects) == len(docs) - for i,d in enumerate(docs):d["q_vec"] = vects[i].tolist() - - -def rm_doc_from_kb(df): - if len(df) == 0:return - for _,r in df.iterrows(): - ES.updateScriptByQuery(Q("term", doc_id=r["did"]), - scripts=""" - if(ctx._source.kb_id.contains('%s')) - ctx._source.kb_id.remove( 
- ctx._source.kb_id.indexOf('%s') - ); - """%(str(r["kb_id"]),str(r["kb_id"])), - idxnm = search.index_name(r["uid"]) - ) - if len(df) == 0:return - sql = """ - delete from kb2_doc where id in (%s) - """%",".join([str(i) for i in df["kb2doc_id"]]) - PG.update(sql) - - -def main(comm, mod): - global model - from llm import HuEmbedding - model = HuEmbedding() - tm_fnm = f"res/{comm}-{mod}.tm" - tm = findMaxDt(tm_fnm) - rows = collect(comm, mod, tm) - if len(rows) == 0:return - - rm_doc_from_kb(rows.loc[rows.is_deleted == True]) - rows = rows.loc[rows.is_deleted == False].reset_index(drop=True) - if len(rows) == 0:return - tmf = open(tm_fnm, "a+") - for _, r in rows.iterrows(): - cks = build(r) - if not cks: - tmf.write(str(r["updated_at"]) + "\n") - continue - ## TODO: exception handler - ## set_progress(r["did"], -1, "ERROR: ") - embedding(cks) - - set_progress(r["kb2doc_id"], random.randint(70, 95)/100., - "Finished embedding! Start to build index!") - init_kb(r) - es_r = ES.bulk(cks, search.index_name(r["uid"])) - if es_r: - set_progress(r["kb2doc_id"], -1, "Index failure!") - print(es_r) - else: set_progress(r["kb2doc_id"], 1., "Done!") - tmf.write(str(r["updated_at"]) + "\n") - tmf.close() - - -if __name__ == "__main__": - from mpi4py import MPI - comm = MPI.COMM_WORLD - main(comm.Get_size(), comm.Get_rank()) - diff --git a/python/tmp.log b/python/tmp.log new file mode 100644 index 0000000000000000000000000000000000000000..2924c5f29821f02345b5347be28ce52720ab3f76 --- /dev/null +++ b/python/tmp.log @@ -0,0 +1,15 @@ + Fetching 6 files: 0%| | 0/6 [00:00/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)" try: diff --git a/python/nlp/query.py b/rag/nlp/query.py old mode 100755 new mode 100644 similarity index 98% rename from python/nlp/query.py rename to rag/nlp/query.py index 6b8f0ab8064fd8e0fde5acf7aa3be6e113c6caf8..de1edd7541af2b3d26571d19558b0ed0d64e5bf3 --- a/python/nlp/query.py +++ b/rag/nlp/query.py @@ -1,12 +1,12 @@ +# -*- coding: utf-8 -*- + import json import re -import sys -import os import logging import copy import math from elasticsearch_dsl import Q, Search -from nlp import huqie, term_weight, synonym +from rag.nlp import huqie, term_weight, synonym class EsQueryer: diff --git a/python/nlp/search.py b/rag/nlp/search.py similarity index 97% rename from python/nlp/search.py rename to rag/nlp/search.py index 2388e2a50f4adacd6b7e038d7f95c0e63f212200..05ce6276fbd9b7877c70e50627c6d35946efd320 100644 --- a/python/nlp/search.py +++ b/rag/nlp/search.py @@ -1,13 +1,11 @@ +# -*- coding: utf-8 -*- import re from elasticsearch_dsl import Q, Search, A from typing import List, Optional, Tuple, Dict, Union from dataclasses import dataclass -from util import setup_logging, rmSpace -from nlp import huqie, query -from datetime import datetime -from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity +from rag.utils import rmSpace +from rag.nlp import huqie, query import numpy as np -from copy import deepcopy def index_name(uid): return f"docgpt_{uid}" diff --git a/python/nlp/synonym.py b/rag/nlp/synonym.py old mode 100755 new mode 100644 similarity index 76% rename from python/nlp/synonym.py rename to rag/nlp/synonym.py index cbe88ce1ef359cbcfe574c23629565f8efde9d74..895fab3d37ca704e8ca8b5ba7e4c3193e7934927 --- a/python/nlp/synonym.py +++ b/rag/nlp/synonym.py @@ -1,8 +1,11 @@ import json +import os import time import logging import re +from web_server.utils.file_utils import get_project_base_directory + class Dealer: def __init__(self, 
redis=None): @@ -10,18 +13,12 @@ class Dealer: self.lookup_num = 100000000 self.load_tm = time.time() - 1000000 self.dictionary = None + path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json") try: - self.dictionary = json.load(open("./synonym.json", 'r')) - except Exception as e: - pass - try: - self.dictionary = json.load(open("./res/synonym.json", 'r')) + self.dictionary = json.load(open(path, 'r')) except Exception as e: - try: - self.dictionary = json.load(open("../res/synonym.json", 'r')) - except Exception as e: - logging.warn("Miss synonym.json") - self.dictionary = {} + logging.warn("Miss synonym.json") + self.dictionary = {} if not redis: logging.warning( diff --git a/python/nlp/term_weight.py b/rag/nlp/term_weight.py old mode 100755 new mode 100644 similarity index 95% rename from python/nlp/term_weight.py rename to rag/nlp/term_weight.py index 4cd3d74c8ae6964d4bdcd6ede05b1b3495d7408c..14e8bfc3cfa54e70aa0e48ebc4cb086d1f781d6b --- a/python/nlp/term_weight.py +++ b/rag/nlp/term_weight.py @@ -1,9 +1,11 @@ +# -*- coding: utf-8 -*- import math import json import re import os import numpy as np -from nlp import huqie +from rag.nlp import huqie +from web_server.utils.file_utils import get_project_base_directory class Dealer: @@ -60,16 +62,14 @@ class Dealer: return set(res.keys()) return res - fnm = os.path.join(os.path.dirname(__file__), '../res/') - if not os.path.exists(fnm): - fnm = os.path.join(os.path.dirname(__file__), '../../res/') + fnm = os.path.join(get_project_base_directory(), "res") self.ne, self.df = {}, {} try: - self.ne = json.load(open(fnm + "ner.json", "r")) + self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r")) except Exception as e: print("[WARNING] Load ner.json FAIL!") try: - self.df = load_dict(fnm + "term.freq") + self.df = load_dict(os.path.join(fnm, "term.freq")) except Exception as e: print("[WARNING] Load term.freq FAIL!") diff --git a/python/parser/__init__.py b/rag/parser/__init__.py similarity index 100% rename from python/parser/__init__.py rename to rag/parser/__init__.py diff --git a/python/parser/docx_parser.py b/rag/parser/docx_parser.py similarity index 98% rename from python/parser/docx_parser.py rename to rag/parser/docx_parser.py index 5968b0eaee8c17102529a057881dd1561a6f9a33..ae63a6839e133a5f74dd713dd3d456d147b30c8a 100644 --- a/python/parser/docx_parser.py +++ b/rag/parser/docx_parser.py @@ -1,8 +1,9 @@ +# -*- coding: utf-8 -*- from docx import Document import re import pandas as pd from collections import Counter -from nlp import huqie +from rag.nlp import huqie from io import BytesIO diff --git a/rag/parser/excel_parser.py b/rag/parser/excel_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..10f3b283a106e7110b0e538ae19a4ea9d5040097 --- /dev/null +++ b/rag/parser/excel_parser.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +from openpyxl import load_workbook +import sys +from io import BytesIO + + +class HuExcelParser: + def __call__(self, fnm): + if isinstance(fnm, str): + wb = load_workbook(fnm) + else: + wb = load_workbook(BytesIO(fnm)) + res = [] + for sheetname in wb.sheetnames: + ws = wb[sheetname] + rows = list(ws.rows) + ti = list(rows[0]) + for r in list(rows[1:]): + l = [] + for i,c in enumerate(r): + if not c.value:continue + t = str(ti[i].value) if i < len(ti) else "" + t += (":" if t else "") + str(c.value) + l.append(t) + l = "; ".join(l) + if sheetname.lower().find("sheet") <0: l += " ——"+sheetname + res.append(l) + return res + + +if __name__ == "__main__": + psr = 
HuExcelParser() + psr(sys.argv[1]) diff --git a/python/parser/pdf_parser.py b/rag/parser/pdf_parser.py similarity index 99% rename from python/parser/pdf_parser.py rename to rag/parser/pdf_parser.py index 71aa50dfd094c8472d4bf0080d0114c6dabe51a3..ea2bb6c584ea07593438d09195b57294e55db14d 100644 --- a/python/parser/pdf_parser.py +++ b/rag/parser/pdf_parser.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import xgboost as xgb from io import BytesIO import torch @@ -6,11 +7,11 @@ import pdfplumber import logging from PIL import Image import numpy as np -from nlp import huqie +from rag.nlp import huqie from collections import Counter from copy import deepcopy -from cv.table_recognize import TableTransformer -from cv.ppdetection import PPDet +from rag.cv.table_recognize import TableTransformer +from rag.cv.ppdetection import PPDet from huggingface_hub import hf_hub_download logging.getLogger("pdfminer").setLevel(logging.WARNING) diff --git a/python/res/huqie.txt b/rag/res/huqie.txt similarity index 100% rename from python/res/huqie.txt rename to rag/res/huqie.txt diff --git a/python/res/ner.json b/rag/res/ner.json old mode 100755 new mode 100644 similarity index 100% rename from python/res/ner.json rename to rag/res/ner.json diff --git a/python/res/synonym.json b/rag/res/synonym.json similarity index 100% rename from python/res/synonym.json rename to rag/res/synonym.json diff --git a/rag/settings.py b/rag/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..613abe75c3789dc599169c3fb938a7abf3f1c17c --- /dev/null +++ b/rag/settings.py @@ -0,0 +1,37 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from web_server.utils import get_base_config,decrypt_database_config +from web_server.utils.file_utils import get_project_base_directory +from web_server.utils.log_utils import LoggerFactory, getLogger + + +# Server +RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf") +SUBPROCESS_STD_LOG_NAME = "std.log" + +ES = get_base_config("es", {}) +MINIO = decrypt_database_config(name="minio") +DOC_MAXIMUM_SIZE = 64 * 1024 * 1024 + +# Logger +LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag")) +# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} +LoggerFactory.LEVEL = 10 + +es_logger = getLogger("es") +minio_logger = getLogger("minio") +cron_logger = getLogger("cron_logger") diff --git a/rag/svr/parse_user_docs.py b/rag/svr/parse_user_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..29d2c28764e93d654330478dd2cd4d7f52ce1808 --- /dev/null +++ b/rag/svr/parse_user_docs.py @@ -0,0 +1,279 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import os +import hashlib +import copy +import time +import random +import re +from timeit import default_timer as timer + +from rag.llm import EmbeddingModel, CvModel +from rag.settings import cron_logger, DOC_MAXIMUM_SIZE +from rag.utils import ELASTICSEARCH, num_tokens_from_string +from rag.utils import MINIO +from rag.utils import rmSpace, findMaxDt +from rag.nlp import huchunk, huqie, search +from io import BytesIO +import pandas as pd +from elasticsearch_dsl import Q +from PIL import Image +from rag.parser import ( + PdfParser, + DocxParser, + ExcelParser +) +from rag.nlp.huchunk import ( + PdfChunker, + DocxChunker, + ExcelChunker, + PptChunker, + TextChunker +) +from web_server.db import LLMType +from web_server.db.services.document_service import DocumentService +from web_server.db.services.llm_service import TenantLLMService +from web_server.utils import get_format_time +from web_server.utils.file_utils import get_project_base_directory + +BATCH_SIZE = 64 + +PDF = PdfChunker(PdfParser()) +DOC = DocxChunker(DocxParser()) +EXC = ExcelChunker(ExcelParser()) +PPT = PptChunker() + + +def chuck_doc(name, binary, cvmdl=None): + suff = os.path.split(name)[-1].lower().split(".")[-1] + if suff.find("pdf") >= 0: + return PDF(binary) + if suff.find("doc") >= 0: + return DOC(binary) + if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): + return EXC(binary) + if suff.find("ppt") >= 0: + return PPT(binary) + if cvmdl and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$", + name.lower()): + txt = cvmdl.describe(binary) + field = TextChunker.Fields() + field.text_chunks = [(txt, binary)] + field.table_chunks = [] + + return TextChunker()(binary) + + +def collect(comm, mod, tm): + docs = DocumentService.get_newly_uploaded(tm, mod, comm) + if len(docs) == 0: + return pd.DataFrame() + docs = pd.DataFrame(docs) + mtm = str(docs["update_time"].max())[:19] + cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) + return docs + + +def set_progress(docid, prog, msg="Processing...", begin=False): + d = {"progress": prog, "progress_msg": msg} + if begin: + d["process_begin_at"] = get_format_time() + try: + DocumentService.update_by_id( + docid, {"progress": prog, "progress_msg": msg}) + except Exception as e: + cron_logger.error("set_progress:({}), {}".format(docid, str(e))) + + +def build(row): + if row["size"] > DOC_MAXIMUM_SIZE: + set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % + (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) + return [] + res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) + if ELASTICSEARCH.getTotal(res) > 0: + ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), + scripts=""" + if(!ctx._source.kb_id.contains('%s')) + ctx._source.kb_id.add('%s'); + """ % (str(row["kb_id"]), str(row["kb_id"])), + idxnm=search.index_name(row["tenant_id"]) + ) + set_progress(row["id"], 1, "Done") + return [] + + random.seed(time.time()) + set_progress(row["id"], random.randint(0, 20) / + 100., "Finished preparing! 
Start to slice file!", True) + try: + obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"])) + except Exception as e: + if re.search("(No such file|not found)", str(e)): + set_progress( + row["id"], -1, "Can not find file <%s>" % + row["doc_name"]) + else: + set_progress( + row["id"], -1, f"Internal server error: %s" % + str(e).replace( + "'", "")) + return [] + + if not obj.text_chunks and not obj.table_chunks: + set_progress( + row["id"], + 1, + "Nothing added! Mostly, file type unsupported yet.") + return [] + + set_progress(row["id"], random.randint(20, 60) / 100., + "Finished slicing files. Start to embedding the content.") + + doc = { + "doc_id": row["did"], + "kb_id": [str(row["kb_id"])], + "docnm_kwd": os.path.split(row["location"])[-1], + "title_tks": huqie.qie(row["name"]), + "updated_at": str(row["update_time"]).replace("T", " ")[:19] + } + doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) + output_buffer = BytesIO() + docs = [] + md5 = hashlib.md5() + for txt, img in obj.text_chunks: + d = copy.deepcopy(doc) + md5.update((txt + str(d["doc_id"])).encode("utf-8")) + d["_id"] = md5.hexdigest() + d["content_ltks"] = huqie.qie(txt) + d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) + if not img: + docs.append(d) + continue + + if isinstance(img, Image): + img.save(output_buffer, format='JPEG') + else: + output_buffer = BytesIO(img) + + MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) + d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) + docs.append(d) + + for arr, img in obj.table_chunks: + for i, txt in enumerate(arr): + d = copy.deepcopy(doc) + d["content_ltks"] = huqie.qie(txt) + md5.update((txt + str(d["doc_id"])).encode("utf-8")) + d["_id"] = md5.hexdigest() + if not img: + docs.append(d) + continue + img.save(output_buffer, format='JPEG') + MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) + d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) + docs.append(d) + set_progress(row["id"], random.randint(60, 70) / + 100., "Continue embedding the content.") + + return docs + + +def init_kb(row): + idxnm = search.index_name(row["tenant_id"]) + if ELASTICSEARCH.indexExist(idxnm): + return + return ELASTICSEARCH.createIdx(idxnm, json.load( + open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) + + +def embedding(docs, mdl): + tts, cnts = [rmSpace(d["title_tks"]) for d in docs], [rmSpace(d["content_ltks"]) for d in docs] + tk_count = 0 + tts, c = mdl.encode(tts) + tk_count += c + cnts, c = mdl.encode(cnts) + tk_count += c + vects = 0.1 * tts + 0.9 * cnts + assert len(vects) == len(docs) + for i, d in enumerate(docs): + d["q_vec"] = vects[i].tolist() + return tk_count + + +def model_instance(tenant_id, llm_type): + model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING) + if not model_config:return + model_config = model_config[0] + if llm_type == LLMType.EMBEDDING: + if model_config.llm_factory not in EmbeddingModel: return + return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) + if llm_type == LLMType.IMAGE2TEXT: + if model_config.llm_factory not in CvModel: return + return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) + + +def main(comm, mod): + global model + from rag.llm import HuEmbedding + model = HuEmbedding() + tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") + tm = findMaxDt(tm_fnm) + rows = collect(comm, mod, tm) + if len(rows) == 0: + return + + tmf = 
open(tm_fnm, "a+") + for _, r in rows.iterrows(): + embd_mdl = model_instance(r["tenant_id"], LLMType.EMBEDDING) + if not embd_mdl: + set_progress(r["id"], -1, "Can't find embedding model!") + cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"])) + continue + cv_mdl = model_instance(r["tenant_id"], LLMType.IMAGE2TEXT) + st_tm = timer() + cks = build(r, cv_mdl) + if not cks: + tmf.write(str(r["updated_at"]) + "\n") + continue + # TODO: exception handler + ## set_progress(r["did"], -1, "ERROR: ") + try: + tk_count = embedding(cks, embd_mdl) + except Exception as e: + set_progress(r["id"], -1, "Embedding error:{}".format(str(e))) + cron_logger.error(str(e)) + continue + + + set_progress(r["id"], random.randint(70, 95) / 100., + "Finished embedding! Start to build index!") + init_kb(r) + es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"])) + if es_r: + set_progress(r["id"], -1, "Index failure!") + cron_logger.error(str(es_r)) + else: + set_progress(r["id"], 1., "Done!") + DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm}) + tmf.write(str(r["update_time"]) + "\n") + tmf.close() + + +if __name__ == "__main__": + from mpi4py import MPI + comm = MPI.COMM_WORLD + main(comm.Get_size(), comm.Get_rank()) diff --git a/python/util/__init__.py b/rag/utils/__init__.py similarity index 50% rename from python/util/__init__.py rename to rag/utils/__init__.py index bc20189d1f714a9d72c9c4d4282cf30274c6ed2b..d3f1632334f390724cead41352cc14491d24f845 100644 --- a/python/util/__init__.py +++ b/rag/utils/__init__.py @@ -1,6 +1,23 @@ +import os import re +import tiktoken +def singleton(cls, *args, **kw): + instances = {} + + def _singleton(): + key = str(cls) + str(os.getpid()) + if key not in instances: + instances[key] = cls(*args, **kw) + return instances[key] + + return _singleton + + +from .minio_conn import MINIO +from .es_conn import ELASTICSEARCH + def rmSpace(txt): txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt) return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt) @@ -22,3 +39,9 @@ def findMaxDt(fnm): except Exception as e: print("WARNING: can't find " + fnm) return m + +def num_tokens_from_string(string: str) -> int: + """Returns the number of tokens in a text string.""" + encoding = tiktoken.get_encoding('cl100k_base') + num_tokens = len(encoding.encode(string)) + return num_tokens \ No newline at end of file diff --git a/python/util/es_conn.py b/rag/utils/es_conn.py old mode 100755 new mode 100644 similarity index 83% rename from python/util/es_conn.py rename to rag/utils/es_conn.py index 228c15460c4fb4748d8fd1677f2b97be60c3df2c..f8337c01db0599f65c61556299a20985efcf2f2c --- a/python/util/es_conn.py +++ b/rag/utils/es_conn.py @@ -1,51 +1,39 @@ import re -import logging import json import time import copy import elasticsearch from elasticsearch import Elasticsearch -from elasticsearch_dsl import UpdateByQuery, Search, Index, Q -from util import config +from elasticsearch_dsl import UpdateByQuery, Search, Index +from rag.settings import es_logger +from rag import settings +from rag.utils import singleton -logging.info("Elasticsearch version: ", elasticsearch.__version__) - - -def instance(env): - CF = config.init(env) - ES_DRESS = CF.get("es").split(",") - - ES = Elasticsearch( - ES_DRESS, - timeout=600 - ) - - logging.info("ES: ", ES_DRESS, ES.info()) - - return ES +es_logger.info("Elasticsearch version: "+ str(elasticsearch.__version__)) +@singleton class HuEs: - def __init__(self, env): - 
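# A minimal sketch of the per-process singleton decorator added in rag/utils/__init__.py:
# decorating a class replaces it with a zero-argument factory that caches one instance per
# (class, pid), which is why the refactored HuEs/HuMinio constructors no longer take an env.
# "Connection" is a hypothetical example class, not something defined in this patch.
import os
from rag.utils import singleton

@singleton
class Connection:
    def __init__(self):
        self.pid = os.getpid()   # recorded only to show the per-process scope

a, b = Connection(), Connection()
assert a is b                    # the same cached instance within one worker process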
self.env = env + def __init__(self): self.info = {} - self.config = config.init(env) self.conn() - self.idxnm = self.config.get("idx_nm", "") + self.idxnm = settings.ES.get("index_name", "") if not self.es.ping(): raise Exception("Can't connect to ES cluster") def conn(self): for _ in range(10): try: - c = instance(self.env) - if c: - self.es = c - self.info = c.info() - logging.info("Connect to es.") + self.es = Elasticsearch( + settings.ES["hosts"].split(","), + timeout=600 + ) + if self.es: + self.info = self.es.info() + es_logger.info("Connect to es.") break except Exception as e: - logging.error("Fail to connect to es: " + str(e)) + es_logger.error("Fail to connect to es: " + str(e)) time.sleep(1) def version(self): @@ -80,12 +68,12 @@ class HuEs: refresh=False, doc_type="_doc", retry_on_conflict=100) - logging.info("Successfully upsert: %s" % id) + es_logger.info("Successfully upsert: %s" % id) T = True break except Exception as e: - logging.warning("Fail to index: " + - json.dumps(d, ensure_ascii=False) + str(e)) + es_logger.warning("Fail to index: " + + json.dumps(d, ensure_ascii=False) + str(e)) if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): time.sleep(3) continue @@ -94,7 +82,7 @@ class HuEs: if not T: res.append(d) - logging.error( + es_logger.error( "Fail to index: " + re.sub( "[\r\n]", @@ -147,7 +135,7 @@ class HuEs: return res except Exception as e: - logging.warn("Fail to bulk: " + str(e)) + es_logger.warn("Fail to bulk: " + str(e)) if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): time.sleep(3) continue @@ -162,7 +150,7 @@ class HuEs: ids[id] = copy.deepcopy(d["raw"]) acts.append({"update": {"_id": id, "_index": self.idxnm}}) acts.append(d["script"]) - logging.info("bulk upsert: %s" % id) + es_logger.info("bulk upsert: %s" % id) res = [] for _ in range(10): @@ -189,7 +177,7 @@ class HuEs: return res except Exception as e: - logging.warning("Fail to bulk: " + str(e)) + es_logger.warning("Fail to bulk: " + str(e)) if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): time.sleep(3) continue @@ -212,10 +200,10 @@ class HuEs: id=d["id"], refresh=True, doc_type="_doc") - logging.info("Remove %s" % d["id"]) + es_logger.info("Remove %s" % d["id"]) return True except Exception as e: - logging.warn("Fail to delete: " + str(d) + str(e)) + es_logger.warn("Fail to delete: " + str(d) + str(e)) if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): time.sleep(3) continue @@ -223,7 +211,7 @@ class HuEs: return True self.conn() - logging.error("Fail to delete: " + str(d)) + es_logger.error("Fail to delete: " + str(d)) return False @@ -242,7 +230,7 @@ class HuEs: raise Exception("Es Timeout.") return res except Exception as e: - logging.error( + es_logger.error( "ES search exception: " + str(e) + "【Q】:" + @@ -250,7 +238,7 @@ class HuEs: if str(e).find("Timeout") > 0: continue raise e - logging.error("ES search timeout for 3 times!") + es_logger.error("ES search timeout for 3 times!") raise Exception("ES search timeout.") def updateByQuery(self, q, d): @@ -267,8 +255,8 @@ class HuEs: r = ubq.execute() return True except Exception as e: - logging.error("ES updateByQuery exception: " + - str(e) + "【Q】:" + str(q.to_dict())) + es_logger.error("ES updateByQuery exception: " + + str(e) + "【Q】:" + str(q.to_dict())) if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: continue self.conn() @@ -288,8 +276,8 @@ class HuEs: r = ubq.execute() return True except Exception as e: - logging.error("ES updateByQuery exception: " + - str(e) + "【Q】:" + str(q.to_dict())) 
+ es_logger.error("ES updateByQuery exception: " + + str(e) + "【Q】:" + str(q.to_dict())) if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: continue self.conn() @@ -304,8 +292,8 @@ class HuEs: body=Search().query(query).to_dict()) return True except Exception as e: - logging.error("ES updateByQuery deleteByQuery: " + - str(e) + "【Q】:" + str(query.to_dict())) + es_logger.error("ES updateByQuery deleteByQuery: " + + str(e) + "【Q】:" + str(query.to_dict())) if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: continue @@ -329,8 +317,9 @@ class HuEs: routing=routing, refresh=False) # , doc_type="_doc") return True except Exception as e: - logging.error("ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) + - json.dumps(script, ensure_ascii=False)) + es_logger.error( + "ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) + + json.dumps(script, ensure_ascii=False)) if str(e).find("Timeout") > 0: continue @@ -342,7 +331,7 @@ class HuEs: try: return s.exists() except Exception as e: - logging.error("ES updateByQuery indexExist: " + str(e)) + es_logger.error("ES updateByQuery indexExist: " + str(e)) if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: continue @@ -354,7 +343,7 @@ class HuEs: return self.es.exists(index=(idxnm if idxnm else self.idxnm), id=docid) except Exception as e: - logging.error("ES Doc Exist: " + str(e)) + es_logger.error("ES Doc Exist: " + str(e)) if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: continue return False @@ -368,13 +357,13 @@ class HuEs: settings=mapping["settings"], mappings=mapping["mappings"]) except Exception as e: - logging.error("ES create index error %s ----%s" % (idxnm, str(e))) + es_logger.error("ES create index error %s ----%s" % (idxnm, str(e))) def deleteIdx(self, idxnm): try: return self.es.indices.delete(idxnm, allow_no_indices=True) except Exception as e: - logging.error("ES delete index error %s ----%s" % (idxnm, str(e))) + es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e))) def getTotal(self, res): if isinstance(res["hits"]["total"], type({})): @@ -393,7 +382,7 @@ class HuEs: return rr def scrollIter(self, pagesize=100, scroll_time='2m', q={ - "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}): + "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}): for _ in range(100): try: page = self.es.search( @@ -405,12 +394,12 @@ class HuEs: ) break except Exception as e: - logging.error("ES scrolling fail. " + str(e)) + es_logger.error("ES scrolling fail. " + str(e)) time.sleep(3) sid = page['_scroll_id'] scroll_size = page['hits']['total']["value"] - logging.info("[TOTAL]%d" % scroll_size) + es_logger.info("[TOTAL]%d" % scroll_size) # Start scrolling while scroll_size > 0: yield page["hits"]["hits"] @@ -419,10 +408,13 @@ class HuEs: page = self.es.scroll(scroll_id=sid, scroll=scroll_time) break except Exception as e: - logging.error("ES scrolling fail. " + str(e)) + es_logger.error("ES scrolling fail. 
" + str(e)) time.sleep(3) # Update the scroll ID sid = page['_scroll_id'] # Get the number of results that we returned in the last scroll scroll_size = len(page['hits']['hits']) + + +ELASTICSEARCH = HuEs() diff --git a/python/util/minio_conn.py b/rag/utils/minio_conn.py similarity index 62% rename from python/util/minio_conn.py rename to rag/utils/minio_conn.py index 141fa3de0e2a91df6d3f824d5438199e39d702dc..14a7067b90d44e0f10cd1a86a266952fa3009dcd 100644 --- a/python/util/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -1,13 +1,15 @@ -import logging +import os import time -from util import config from minio import Minio from io import BytesIO +from rag import settings +from rag.settings import minio_logger +from rag.utils import singleton +@singleton class HuMinio(object): - def __init__(self, env): - self.config = config.init(env) + def __init__(self): self.conn = None self.__open__() @@ -19,15 +21,14 @@ class HuMinio(object): pass try: - self.conn = Minio(self.config.get("minio_host"), - access_key=self.config.get("minio_user"), - secret_key=self.config.get("minio_password"), + self.conn = Minio(settings.MINIO["host"], + access_key=settings.MINIO["user"], + secret_key=settings.MINIO["passwd"], secure=False ) except Exception as e: - logging.error( - "Fail to connect %s " % - self.config.get("minio_host") + str(e)) + minio_logger.error( + "Fail to connect %s " % settings.MINIO["host"] + str(e)) def __close__(self): del self.conn @@ -45,34 +46,51 @@ class HuMinio(object): ) return r except Exception as e: - logging.error(f"Fail put {bucket}/{fnm}: " + str(e)) + minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e)) self.__open__() time.sleep(1) + def rm(self, bucket, fnm): + try: + self.conn.remove_object(bucket, fnm) + except Exception as e: + minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e)) + + def get(self, bucket, fnm): for _ in range(10): try: r = self.conn.get_object(bucket, fnm) return r.read() except Exception as e: - logging.error(f"fail get {bucket}/{fnm}: " + str(e)) + minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e)) self.__open__() time.sleep(1) return + def obj_exist(self, bucket, fnm): + try: + if self.conn.stat_object(bucket, fnm):return True + return False + except Exception as e: + minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e)) + return False + + def get_presigned_url(self, bucket, fnm, expires): for _ in range(10): try: return self.conn.get_presigned_url("GET", bucket, fnm, expires) except Exception as e: - logging.error(f"fail get {bucket}/{fnm}: " + str(e)) + minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e)) self.__open__() time.sleep(1) return +MINIO = HuMinio() if __name__ == "__main__": - conn = HuMinio("infiniflow") + conn = HuMinio() fnm = "/opt/home/kevinhu/docgpt/upload/13/11-408.jpg" from PIL import Image img = Image.open(fnm) diff --git a/src/api/dialog_info.rs b/src/api/dialog_info.rs deleted file mode 100644 index def4f53156eaf602cae25c6fde2d986db55a8b33..0000000000000000000000000000000000000000 --- a/src/api/dialog_info.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::collections::HashMap; -use actix_web::{ HttpResponse, post, web }; -use serde::Deserialize; -use serde_json::Value; -use serde_json::json; -use crate::api::JsonResponse; -use crate::AppState; -use crate::errors::AppError; -use crate::service::dialog_info::Query; -use crate::service::dialog_info::Mutation; - -#[derive(Debug, Deserialize)] -pub struct ListParams { - pub uid: i64, - pub dialog_id: Option, -} -#[post("/v1.0/dialogs")] -async fn list( - params: 
web::Json, - data: web::Data -) -> Result { - let mut result = HashMap::new(); - if let Some(dia_id) = params.dialog_id { - let dia = Query::find_dialog_info_by_id(&data.conn, dia_id).await?.unwrap(); - let kb = crate::service::kb_info::Query - ::find_kb_info_by_id(&data.conn, dia.kb_id).await? - .unwrap(); - print!("{:?}", dia.history); - let hist: Value = serde_json::from_str(&dia.history)?; - let detail = - json!({ - "dialog_id": dia_id, - "dialog_name": dia.dialog_name.to_owned(), - "created_at": dia.created_at.to_string().to_owned(), - "updated_at": dia.updated_at.to_string().to_owned(), - "history": hist, - "kb_info": kb - }); - - result.insert("dialogs", vec![detail]); - } else { - let mut dias = Vec::::new(); - for dia in Query::find_dialog_infos_by_uid(&data.conn, params.uid).await? { - let kb = crate::service::kb_info::Query - ::find_kb_info_by_id(&data.conn, dia.kb_id).await? - .unwrap(); - let hist: Value = serde_json::from_str(&dia.history)?; - dias.push( - json!({ - "dialog_id": dia.dialog_id, - "dialog_name": dia.dialog_name.to_owned(), - "created_at": dia.created_at.to_string().to_owned(), - "updated_at": dia.updated_at.to_string().to_owned(), - "history": hist, - "kb_info": kb - }) - ); - } - result.insert("dialogs", dias); - } - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[derive(Debug, Deserialize)] -pub struct RmParams { - pub uid: i64, - pub dialog_id: i64, -} -#[post("/v1.0/delete_dialog")] -async fn delete( - params: web::Json, - data: web::Data -) -> Result { - let _ = Mutation::delete_dialog_info(&data.conn, params.dialog_id).await?; - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[derive(Debug, Deserialize)] -pub struct CreateParams { - pub uid: i64, - pub dialog_id: Option, - pub kb_id: i64, - pub name: String, -} -#[post("/v1.0/create_dialog")] -async fn create( - param: web::Json, - data: web::Data -) -> Result { - let mut result = HashMap::new(); - if let Some(dia_id) = param.dialog_id { - result.insert("dialog_id", dia_id); - let dia = Query::find_dialog_info_by_id(&data.conn, dia_id).await?; - let _ = Mutation::update_dialog_info_by_id( - &data.conn, - dia_id, - ¶m.name, - &dia.unwrap().history - ).await?; - } else { - let dia = Mutation::create_dialog_info( - &data.conn, - param.uid, - param.kb_id, - ¶m.name - ).await?; - result.insert("dialog_id", dia.dialog_id.unwrap()); - } - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[derive(Debug, Deserialize)] -pub struct UpdateHistoryParams { - pub uid: i64, - pub dialog_id: i64, - pub history: Value, -} -#[post("/v1.0/update_history")] -async fn update_history( - param: web::Json, - data: web::Data -) -> Result { - let mut json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - if let Some(dia) = Query::find_dialog_info_by_id(&data.conn, param.dialog_id).await? 
{ - let _ = Mutation::update_dialog_info_by_id( - &data.conn, - param.dialog_id, - &dia.dialog_name, - ¶m.history.to_string() - ).await?; - } else { - json_response.code = 500; - json_response.err = "Can't find dialog data!".to_owned(); - } - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} diff --git a/src/api/doc_info.rs b/src/api/doc_info.rs deleted file mode 100644 index 6d3f5a3a9434146904fbbc195ca330222fbaf8a0..0000000000000000000000000000000000000000 --- a/src/api/doc_info.rs +++ /dev/null @@ -1,314 +0,0 @@ -use std::collections::{HashMap}; -use std::io::BufReader; -use actix_multipart_extract::{ File, Multipart, MultipartForm }; -use actix_web::web::Bytes; -use actix_web::{ HttpResponse, post, web }; -use chrono::{ Utc, FixedOffset }; -use minio::s3::args::{ BucketExistsArgs, MakeBucketArgs, PutObjectArgs }; -use sea_orm::DbConn; -use crate::api::JsonResponse; -use crate::AppState; -use crate::entity::doc_info::Model; -use crate::errors::AppError; -use crate::service::doc_info::{ Mutation, Query }; -use serde::Deserialize; -use regex::Regex; - -fn now() -> chrono::DateTime { - Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) -} - -#[derive(Debug, Deserialize)] -pub struct ListParams { - pub uid: i64, - pub filter: FilterParams, - pub sortby: String, - pub page: Option, - pub per_page: Option, -} - -#[derive(Debug, Deserialize)] -pub struct FilterParams { - pub keywords: Option, - pub folder_id: Option, - pub tag_id: Option, - pub kb_id: Option, -} - -#[post("/v1.0/docs")] -async fn list( - params: web::Json, - data: web::Data -) -> Result { - let docs = Query::find_doc_infos_by_params(&data.conn, params.into_inner()).await?; - - let mut result = HashMap::new(); - result.insert("docs", docs); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) 
- ) -} - -#[derive(Deserialize, MultipartForm, Debug)] -pub struct UploadForm { - #[multipart(max_size = 512MB)] - file_field: File, - uid: i64, - did: i64, -} - -fn file_type(filename: &String) -> String { - let fnm = filename.to_lowercase(); - if - let Some(_) = Regex::new(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$") - .unwrap() - .captures(&fnm) - { - return "Video".to_owned(); - } - if - let Some(_) = Regex::new( - r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$" - ) - .unwrap() - .captures(&fnm) - { - return "Picture".to_owned(); - } - if - let Some(_) = Regex::new(r"\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$") - .unwrap() - .captures(&fnm) - { - return "Music".to_owned(); - } - if - let Some(_) = Regex::new(r"\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$") - .unwrap() - .captures(&fnm) - { - return "Document".to_owned(); - } - "Other".to_owned() -} - - -#[post("/v1.0/upload")] -async fn upload( - payload: Multipart, - data: web::Data -) -> Result { - - - Ok(HttpResponse::Ok().body("File uploaded successfully")) -} - -pub(crate) async fn _upload_file(uid: i64, did: i64, file_name: &str, bytes: &[u8], data: &web::Data) -> Result<(), AppError> { - async fn add_number_to_filename( - file_name: &str, - conn: &DbConn, - uid: i64, - parent_id: i64 - ) -> String { - let mut i = 0; - let mut new_file_name = file_name.to_string(); - let arr: Vec<&str> = file_name.split(".").collect(); - let suffix = String::from(arr[arr.len() - 1]); - let preffix = arr[..arr.len() - 1].join("."); - let mut docs = Query::find_doc_infos_by_name( - conn, - uid, - &new_file_name, - Some(parent_id) - ).await.unwrap(); - while docs.len() > 0 { - i += 1; - new_file_name = format!("{}_{}.{}", preffix, i, suffix); - docs = Query::find_doc_infos_by_name( - conn, - uid, - &new_file_name, - Some(parent_id) - ).await.unwrap(); - } - new_file_name - } - let fnm = add_number_to_filename(file_name, &data.conn, uid, did).await; - - let bucket_name = format!("{}-upload", uid); - let s3_client: &minio::s3::client::Client = &data.s3_client; - let buckets_exists = s3_client - .bucket_exists(&BucketExistsArgs::new(&bucket_name).unwrap()).await - .unwrap(); - if !buckets_exists { - print!("Create bucket: {}", bucket_name.clone()); - s3_client.make_bucket(&MakeBucketArgs::new(&bucket_name).unwrap()).await.unwrap(); - } else { - print!("Existing bucket: {}", bucket_name.clone()); - } - - let location = format!("/{}/{}", did, fnm) - .as_bytes() - .to_vec() - .iter() - .map(|b| format!("{:02x}", b).to_string()) - .collect::>() - .join(""); - print!("===>{}", location.clone()); - s3_client.put_object( - &mut PutObjectArgs::new( - &bucket_name, - &location, - &mut BufReader::new(bytes), - Some(bytes.len()), - None - )? 
- ).await?; - - let doc = Mutation::create_doc_info(&data.conn, Model { - did: Default::default(), - uid: uid, - doc_name: fnm.clone(), - size: bytes.len() as i64, - location, - r#type: file_type(&fnm), - thumbnail_base64: Default::default(), - created_at: now(), - updated_at: now(), - is_deleted: Default::default(), - }).await?; - - let _ = Mutation::place_doc(&data.conn, did, doc.did.unwrap()).await?; - - Ok(()) -} - -#[derive(Deserialize, Debug)] -pub struct RmDocsParam { - uid: i64, - dids: Vec, -} -#[post("/v1.0/delete_docs")] -async fn delete( - params: web::Json, - data: web::Data -) -> Result { - let _ = Mutation::delete_doc_info(&data.conn, ¶ms.dids).await?; - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[derive(Debug, Deserialize)] -pub struct MvParams { - pub uid: i64, - pub dids: Vec, - pub dest_did: i64, -} - -#[post("/v1.0/mv_docs")] -async fn mv( - params: web::Json, - data: web::Data -) -> Result { - Mutation::mv_doc_info(&data.conn, params.dest_did, ¶ms.dids).await?; - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[derive(Debug, Deserialize)] -pub struct NewFoldParams { - pub uid: i64, - pub parent_id: i64, - pub name: String, -} - -#[post("/v1.0/new_folder")] -async fn new_folder( - params: web::Json, - data: web::Data -) -> Result { - let doc = Mutation::create_doc_info(&data.conn, Model { - did: Default::default(), - uid: params.uid, - doc_name: params.name.to_string(), - size: 0, - r#type: "folder".to_string(), - location: "".to_owned(), - thumbnail_base64: Default::default(), - created_at: now(), - updated_at: now(), - is_deleted: Default::default(), - }).await?; - let _ = Mutation::place_doc(&data.conn, params.parent_id, doc.did.unwrap()).await?; - - Ok(HttpResponse::Ok().body("Folder created successfully")) -} - -#[derive(Debug, Deserialize)] -pub struct RenameParams { - pub uid: i64, - pub did: i64, - pub name: String, -} - -#[post("/v1.0/rename")] -async fn rename( - params: web::Json, - data: web::Data -) -> Result { - let docs = Query::find_doc_infos_by_name(&data.conn, params.uid, ¶ms.name, None).await?; - if docs.len() > 0 { - let json_response = JsonResponse { - code: 500, - err: "Name duplicated!".to_owned(), - data: (), - }; - return Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ); - } - let doc = Mutation::rename(&data.conn, params.did, ¶ms.name).await?; - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: doc, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) 
- ) -} diff --git a/src/api/kb_info.rs b/src/api/kb_info.rs deleted file mode 100644 index 544e944b03c738333f59bc699e5799d2e5e304bd..0000000000000000000000000000000000000000 --- a/src/api/kb_info.rs +++ /dev/null @@ -1,166 +0,0 @@ -use std::collections::HashMap; -use actix_web::{ get, HttpResponse, post, web }; -use serde::Serialize; -use crate::api::JsonResponse; -use crate::AppState; -use crate::entity::kb_info; -use crate::errors::AppError; -use crate::service::kb_info::Mutation; -use crate::service::kb_info::Query; -use serde::Deserialize; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct AddDocs2KbParams { - pub uid: i64, - pub dids: Vec, - pub kb_id: i64, -} -#[post("/v1.0/create_kb")] -async fn create( - model: web::Json, - data: web::Data -) -> Result { - let mut docs = Query::find_kb_infos_by_name( - &data.conn, - model.kb_name.to_owned() - ).await.unwrap(); - if docs.len() > 0 { - let json_response = JsonResponse { - code: 201, - err: "Duplicated name.".to_owned(), - data: (), - }; - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) - } else { - let model = Mutation::create_kb_info(&data.conn, model.into_inner()).await?; - - let mut result = HashMap::new(); - result.insert("kb_id", model.kb_id.unwrap()); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) - } -} - -#[post("/v1.0/add_docs_to_kb")] -async fn add_docs_to_kb( - param: web::Json, - data: web::Data -) -> Result { - let _ = Mutation::add_docs(&data.conn, param.kb_id, param.dids.to_owned()).await?; - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[post("/v1.0/anti_kb_docs")] -async fn anti_kb_docs( - param: web::Json, - data: web::Data -) -> Result { - let _ = Mutation::remove_docs(&data.conn, param.dids.to_owned(), Some(param.kb_id)).await?; - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} -#[get("/v1.0/kbs")] -async fn list( - model: web::Json, - data: web::Data -) -> Result { - let kbs = Query::find_kb_infos_by_uid(&data.conn, model.uid).await?; - - let mut result = HashMap::new(); - result.insert("kbs", kbs); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[post("/v1.0/delete_kb")] -async fn delete( - model: web::Json, - data: web::Data -) -> Result { - let _ = Mutation::delete_kb_info(&data.conn, model.kb_id).await?; - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) 
- ) -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct DocIdsParams { - pub uid: i64, - pub dids: Vec, -} - -#[post("/v1.0/all_relevents")] -async fn all_relevents( - params: web::Json, - data: web::Data -) -> Result { - let dids = crate::service::doc_info::Query::all_descendent_ids(&data.conn, ¶ms.dids).await?; - let mut result = HashMap::new(); - let kbs = Query::find_kb_by_docs(&data.conn, dids).await?; - result.insert("kbs", kbs); - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} diff --git a/src/api/mod.rs b/src/api/mod.rs deleted file mode 100644 index e3ae3e6080bf10d3ffe6533f9d931787860f3c65..0000000000000000000000000000000000000000 --- a/src/api/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -use serde::{ Deserialize, Serialize }; - -pub(crate) mod tag_info; -pub(crate) mod kb_info; -pub(crate) mod dialog_info; -pub(crate) mod doc_info; -pub(crate) mod user_info; - -#[derive(Debug, Deserialize, Serialize)] -struct JsonResponse { - code: u32, - err: String, - data: T, -} diff --git a/src/api/tag_info.rs b/src/api/tag_info.rs deleted file mode 100644 index 1caa4e5dc2a13c054ad3ee2e00b45bc28fa47368..0000000000000000000000000000000000000000 --- a/src/api/tag_info.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::collections::HashMap; -use actix_web::{ get, HttpResponse, post, web }; -use serde::Deserialize; -use crate::api::JsonResponse; -use crate::AppState; -use crate::entity::tag_info; -use crate::errors::AppError; -use crate::service::tag_info::{ Mutation, Query }; - -#[derive(Debug, Deserialize)] -pub struct TagListParams { - pub uid: i64, -} - -#[post("/v1.0/create_tag")] -async fn create( - model: web::Json, - data: web::Data -) -> Result { - let model = Mutation::create_tag(&data.conn, model.into_inner()).await?; - - let mut result = HashMap::new(); - result.insert("tid", model.tid.unwrap()); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[post("/v1.0/delete_tag")] -async fn delete( - model: web::Json, - data: web::Data -) -> Result { - let _ = Mutation::delete_tag(&data.conn, model.tid).await?; - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -//#[get("/v1.0/tags", wrap = "HttpAuthentication::bearer(validator)")] - -#[post("/v1.0/tags")] -async fn list( - param: web::Json, - data: web::Data -) -> Result { - let tags = Query::find_tags_by_uid(param.uid, &data.conn).await?; - - let mut result = HashMap::new(); - result.insert("tags", tags); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) 
- ) -} diff --git a/src/api/user_info.rs b/src/api/user_info.rs deleted file mode 100644 index 3d9f89e44b27214ab6972eef63febc080eb14e75..0000000000000000000000000000000000000000 --- a/src/api/user_info.rs +++ /dev/null @@ -1,149 +0,0 @@ -use std::collections::HashMap; -use std::io::SeekFrom; -use std::ptr::null; -use actix_identity::Identity; -use actix_web::{ HttpResponse, post, web }; -use chrono::{ FixedOffset, Utc }; -use sea_orm::ActiveValue::NotSet; -use serde::{ Deserialize, Serialize }; -use crate::api::JsonResponse; -use crate::AppState; -use crate::entity::{ doc_info, tag_info }; -use crate::entity::user_info::Model; -use crate::errors::{ AppError, UserError }; -use crate::service::user_info::Mutation; -use crate::service::user_info::Query; - -fn now() -> chrono::DateTime { - Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) -} - -pub(crate) fn create_auth_token(user: &Model) -> u64 { - use std::{ collections::hash_map::DefaultHasher, hash::{ Hash, Hasher } }; - - let mut hasher = DefaultHasher::new(); - user.hash(&mut hasher); - hasher.finish() -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub(crate) struct LoginParams { - pub(crate) email: String, - pub(crate) password: String, -} - -#[post("/v1.0/login")] -async fn login( - data: web::Data, - identity: Identity, - input: web::Json -) -> Result { - match Query::login(&data.conn, &input.email, &input.password).await? { - Some(user) => { - let _ = Mutation::update_login_status(user.uid, &data.conn).await?; - let token = create_auth_token(&user).to_string(); - - identity.remember(token.clone()); - - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: token.clone(), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .append_header(("X-Auth-Token", token)) - .body(serde_json::to_string(&json_response)?) - ) - } - None => Err(UserError::LoginFailed.into()), - } -} - -#[post("/v1.0/register")] -async fn register( - model: web::Json, - data: web::Data -) -> Result { - let mut result = HashMap::new(); - let u = Query::find_user_infos(&data.conn, &model.email).await?; - if let Some(_) = u { - let json_response = JsonResponse { - code: 500, - err: "Email registered!".to_owned(), - data: (), - }; - return Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) 
- ); - } - - let usr = Mutation::create_user(&data.conn, &model).await?; - result.insert("uid", usr.uid.clone().unwrap()); - crate::service::doc_info::Mutation::create_doc_info(&data.conn, doc_info::Model { - did: Default::default(), - uid: usr.uid.clone().unwrap(), - doc_name: "/".into(), - size: 0, - location: "".into(), - thumbnail_base64: "".into(), - r#type: "folder".to_string(), - created_at: now(), - updated_at: now(), - is_deleted: Default::default(), - }).await?; - let tnm = vec!["Video", "Picture", "Music", "Document"]; - let tregx = vec![ - ".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa)", - ".*\\.(png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng)", - ".*\\.(WAV|FLAC|APE|ALAC|WavPack|WV|MP3|AAC|Ogg|Vorbis|Opus)", - ".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp)" - ]; - for i in 0..4 { - crate::service::tag_info::Mutation::create_tag(&data.conn, tag_info::Model { - tid: Default::default(), - uid: usr.uid.clone().unwrap(), - tag_name: tnm[i].to_owned(), - regx: tregx[i].to_owned(), - color: (i + 1).to_owned() as i16, - icon: (i + 1).to_owned() as i16, - folder_id: 0, - created_at: Default::default(), - updated_at: Default::default(), - }).await?; - } - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: result, - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) - ) -} - -#[post("/v1.0/setting")] -async fn setting( - model: web::Json, - data: web::Data -) -> Result { - let _ = Mutation::update_user_by_id(&data.conn, &model).await?; - let json_response = JsonResponse { - code: 200, - err: "".to_owned(), - data: (), - }; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .body(serde_json::to_string(&json_response)?) 
- ) -} diff --git a/src/entity/dialog2_kb.rs b/src/entity/dialog2_kb.rs deleted file mode 100644 index 1f3ebaa7fa41d2483c7de2aaf68ec2e3436eeb5b..0000000000000000000000000000000000000000 --- a/src/entity/dialog2_kb.rs +++ /dev/null @@ -1,38 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; - -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "dialog2_kb")] -pub struct Model { - #[sea_orm(primary_key, auto_increment = true)] - pub id: i64, - #[sea_orm(index)] - pub dialog_id: i64, - #[sea_orm(index)] - pub kb_id: i64, -} - -#[derive(Debug, Clone, Copy, EnumIter)] -pub enum Relation { - DialogInfo, - KbInfo, -} - -impl RelationTrait for Relation { - fn def(&self) -> RelationDef { - match self { - Self::DialogInfo => - Entity::belongs_to(super::dialog_info::Entity) - .from(Column::DialogId) - .to(super::dialog_info::Column::DialogId) - .into(), - Self::KbInfo => - Entity::belongs_to(super::kb_info::Entity) - .from(Column::KbId) - .to(super::kb_info::Column::KbId) - .into(), - } - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entity/dialog_info.rs b/src/entity/dialog_info.rs deleted file mode 100644 index a201695f1dd127451bfa06fd06a5418b16d98f32..0000000000000000000000000000000000000000 --- a/src/entity/dialog_info.rs +++ /dev/null @@ -1,38 +0,0 @@ -use chrono::{ DateTime, FixedOffset }; -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; - -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "dialog_info")] -pub struct Model { - #[sea_orm(primary_key, auto_increment = false)] - pub dialog_id: i64, - #[sea_orm(index)] - pub uid: i64, - #[serde(skip_deserializing)] - pub kb_id: i64, - pub dialog_name: String, - pub history: String, - - #[serde(skip_deserializing)] - pub created_at: DateTime, - #[serde(skip_deserializing)] - pub updated_at: DateTime, - #[serde(skip_deserializing)] - pub is_deleted: bool, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl Related for Entity { - fn to() -> RelationDef { - super::dialog2_kb::Relation::KbInfo.def() - } - - fn via() -> Option { - Some(super::dialog2_kb::Relation::DialogInfo.def().rev()) - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entity/doc2_doc.rs b/src/entity/doc2_doc.rs deleted file mode 100644 index 00ab2ca55b20e13fd7456ea602f9d2a6a8d7dee8..0000000000000000000000000000000000000000 --- a/src/entity/doc2_doc.rs +++ /dev/null @@ -1,38 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; - -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "doc2_doc")] -pub struct Model { - #[sea_orm(primary_key, auto_increment = true)] - pub id: i64, - #[sea_orm(index)] - pub parent_id: i64, - #[sea_orm(index)] - pub did: i64, -} - -#[derive(Debug, Clone, Copy, EnumIter)] -pub enum Relation { - Parent, - Child, -} - -impl RelationTrait for Relation { - fn def(&self) -> RelationDef { - match self { - Self::Parent => - Entity::belongs_to(super::doc_info::Entity) - .from(Column::ParentId) - .to(super::doc_info::Column::Did) - .into(), - Self::Child => - Entity::belongs_to(super::doc_info::Entity) - .from(Column::Did) - .to(super::doc_info::Column::Did) - .into(), - } - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entity/doc_info.rs b/src/entity/doc_info.rs deleted file mode 100644 index 
26c46d509604eaab7c9e99305d0915e0d88f7792..0000000000000000000000000000000000000000 --- a/src/entity/doc_info.rs +++ /dev/null @@ -1,62 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; -use crate::entity::kb_info; -use chrono::{ DateTime, FixedOffset }; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "doc_info")] -pub struct Model { - #[sea_orm(primary_key, auto_increment = false)] - pub did: i64, - #[sea_orm(index)] - pub uid: i64, - pub doc_name: String, - pub size: i64, - #[sea_orm(column_name = "type")] - pub r#type: String, - #[serde(skip_deserializing)] - pub location: String, - #[serde(skip_deserializing)] - pub thumbnail_base64: String, - #[serde(skip_deserializing)] - pub created_at: DateTime, - #[serde(skip_deserializing)] - pub updated_at: DateTime, - #[serde(skip_deserializing)] - pub is_deleted: bool, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl Related for Entity { - fn to() -> RelationDef { - super::tag2_doc::Relation::Tag.def() - } - - fn via() -> Option { - Some(super::tag2_doc::Relation::DocInfo.def().rev()) - } -} - -impl Related for Entity { - fn to() -> RelationDef { - super::kb2_doc::Relation::KbInfo.def() - } - - fn via() -> Option { - Some(super::kb2_doc::Relation::DocInfo.def().rev()) - } -} - -impl Related for Entity { - fn to() -> RelationDef { - super::doc2_doc::Relation::Parent.def() - } - - fn via() -> Option { - Some(super::doc2_doc::Relation::Child.def().rev()) - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entity/kb2_doc.rs b/src/entity/kb2_doc.rs deleted file mode 100644 index bd3d565f32b2b0ae3846ccb3bf409491edafb27d..0000000000000000000000000000000000000000 --- a/src/entity/kb2_doc.rs +++ /dev/null @@ -1,47 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; -use chrono::{ DateTime, FixedOffset }; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "kb2_doc")] -pub struct Model { - #[sea_orm(primary_key, auto_increment = true)] - pub id: i64, - #[sea_orm(index)] - pub kb_id: i64, - #[sea_orm(index)] - pub did: i64, - #[serde(skip_deserializing)] - pub kb_progress: f32, - #[serde(skip_deserializing)] - pub kb_progress_msg: String, - #[serde(skip_deserializing)] - pub updated_at: DateTime, - #[serde(skip_deserializing)] - pub is_deleted: bool, -} - -#[derive(Debug, Clone, Copy, EnumIter)] -pub enum Relation { - DocInfo, - KbInfo, -} - -impl RelationTrait for Relation { - fn def(&self) -> RelationDef { - match self { - Self::DocInfo => - Entity::belongs_to(super::doc_info::Entity) - .from(Column::Did) - .to(super::doc_info::Column::Did) - .into(), - Self::KbInfo => - Entity::belongs_to(super::kb_info::Entity) - .from(Column::KbId) - .to(super::kb_info::Column::KbId) - .into(), - } - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entity/kb_info.rs b/src/entity/kb_info.rs deleted file mode 100644 index e06a637f4c4b916bc2675a1326f657a9e2b264e5..0000000000000000000000000000000000000000 --- a/src/entity/kb_info.rs +++ /dev/null @@ -1,47 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; -use chrono::{ DateTime, FixedOffset }; - -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "kb_info")] -pub struct Model { - #[sea_orm(primary_key, auto_increment = false)] - #[serde(skip_deserializing)] - pub kb_id: i64, - 
#[sea_orm(index)] - pub uid: i64, - pub kb_name: String, - pub icon: i16, - - #[serde(skip_deserializing)] - pub created_at: DateTime, - #[serde(skip_deserializing)] - pub updated_at: DateTime, - #[serde(skip_deserializing)] - pub is_deleted: bool, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl Related for Entity { - fn to() -> RelationDef { - super::kb2_doc::Relation::DocInfo.def() - } - - fn via() -> Option { - Some(super::kb2_doc::Relation::KbInfo.def().rev()) - } -} - -impl Related for Entity { - fn to() -> RelationDef { - super::dialog2_kb::Relation::DialogInfo.def() - } - - fn via() -> Option { - Some(super::dialog2_kb::Relation::KbInfo.def().rev()) - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entity/mod.rs b/src/entity/mod.rs deleted file mode 100644 index dcaae2a183f57b17e0c93a1688a4c6c2ce25e261..0000000000000000000000000000000000000000 --- a/src/entity/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub(crate) mod user_info; -pub(crate) mod tag_info; -pub(crate) mod tag2_doc; -pub(crate) mod kb2_doc; -pub(crate) mod dialog2_kb; -pub(crate) mod doc2_doc; -pub(crate) mod kb_info; -pub(crate) mod doc_info; -pub(crate) mod dialog_info; diff --git a/src/entity/tag2_doc.rs b/src/entity/tag2_doc.rs deleted file mode 100644 index 468c5fde75955de86bd8cbd7b5640a63f3ac55b4..0000000000000000000000000000000000000000 --- a/src/entity/tag2_doc.rs +++ /dev/null @@ -1,38 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; - -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "tag2_doc")] -pub struct Model { - #[sea_orm(primary_key, auto_increment = true)] - pub id: i64, - #[sea_orm(index)] - pub tag_id: i64, - #[sea_orm(index)] - pub did: i64, -} - -#[derive(Debug, Clone, Copy, EnumIter)] -pub enum Relation { - Tag, - DocInfo, -} - -impl RelationTrait for Relation { - fn def(&self) -> sea_orm::RelationDef { - match self { - Self::Tag => - Entity::belongs_to(super::tag_info::Entity) - .from(Column::TagId) - .to(super::tag_info::Column::Tid) - .into(), - Self::DocInfo => - Entity::belongs_to(super::doc_info::Entity) - .from(Column::Did) - .to(super::doc_info::Column::Did) - .into(), - } - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entity/tag_info.rs b/src/entity/tag_info.rs deleted file mode 100644 index b49fa305a952cd4bcb7e6f3cfb2c01763c97de5f..0000000000000000000000000000000000000000 --- a/src/entity/tag_info.rs +++ /dev/null @@ -1,40 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; -use chrono::{ DateTime, FixedOffset }; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "tag_info")] -pub struct Model { - #[sea_orm(primary_key)] - #[serde(skip_deserializing)] - pub tid: i64, - #[sea_orm(index)] - pub uid: i64, - pub tag_name: String, - #[serde(skip_deserializing)] - pub regx: String, - pub color: i16, - pub icon: i16, - #[serde(skip_deserializing)] - pub folder_id: i64, - - #[serde(skip_deserializing)] - pub created_at: DateTime, - #[serde(skip_deserializing)] - pub updated_at: DateTime, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl Related for Entity { - fn to() -> RelationDef { - super::tag2_doc::Relation::DocInfo.def() - } - - fn via() -> Option { - Some(super::tag2_doc::Relation::Tag.def().rev()) - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entity/user_info.rs 
b/src/entity/user_info.rs deleted file mode 100644 index 68fc10e38b3af45b7d2bfd01a6844307ca47535d..0000000000000000000000000000000000000000 --- a/src/entity/user_info.rs +++ /dev/null @@ -1,30 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::{ Deserialize, Serialize }; -use chrono::{ DateTime, FixedOffset }; - -#[derive(Clone, Debug, PartialEq, Eq, Hash, DeriveEntityModel, Deserialize, Serialize)] -#[sea_orm(table_name = "user_info")] -pub struct Model { - #[sea_orm(primary_key)] - #[serde(skip_deserializing)] - pub uid: i64, - pub email: String, - pub nickname: String, - pub avatar_base64: String, - pub color_scheme: String, - pub list_style: String, - pub language: String, - pub password: String, - - #[serde(skip_deserializing)] - pub last_login_at: DateTime, - #[serde(skip_deserializing)] - pub created_at: DateTime, - #[serde(skip_deserializing)] - pub updated_at: DateTime, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/src/errors.rs b/src/errors.rs deleted file mode 100644 index f19e58c17a5272a9f70dd6706c7572104206f33e..0000000000000000000000000000000000000000 --- a/src/errors.rs +++ /dev/null @@ -1,83 +0,0 @@ -use actix_web::{HttpResponse, ResponseError}; -use thiserror::Error; - -#[derive(Debug, Error)] -pub(crate) enum AppError { - #[error("`{0}`")] - User(#[from] UserError), - - #[error("`{0}`")] - Json(#[from] serde_json::Error), - - #[error("`{0}`")] - Actix(#[from] actix_web::Error), - - #[error("`{0}`")] - Db(#[from] sea_orm::DbErr), - - #[error("`{0}`")] - MinioS3(#[from] minio::s3::error::Error), - - #[error("`{0}`")] - Std(#[from] std::io::Error), -} - -#[derive(Debug, Error)] -pub(crate) enum UserError { - #[error("`username` field of `User` cannot be empty!")] - EmptyUsername, - - #[error("`username` field of `User` cannot contain whitespaces!")] - UsernameInvalidCharacter, - - #[error("`password` field of `User` cannot be empty!")] - EmptyPassword, - - #[error("`password` field of `User` cannot contain whitespaces!")] - PasswordInvalidCharacter, - - #[error("Could not find any `User` for id: `{0}`!")] - NotFound(i64), - - #[error("Failed to login user!")] - LoginFailed, - - #[error("User is not logged in!")] - NotLoggedIn, - - #[error("Invalid authorization token!")] - InvalidToken, - - #[error("Could not find any `User`!")] - Empty, -} - -impl ResponseError for AppError { - fn status_code(&self) -> actix_web::http::StatusCode { - match self { - AppError::User(user_error) => match user_error { - UserError::EmptyUsername => actix_web::http::StatusCode::UNPROCESSABLE_ENTITY, - UserError::UsernameInvalidCharacter => { - actix_web::http::StatusCode::UNPROCESSABLE_ENTITY - } - UserError::EmptyPassword => actix_web::http::StatusCode::UNPROCESSABLE_ENTITY, - UserError::PasswordInvalidCharacter => { - actix_web::http::StatusCode::UNPROCESSABLE_ENTITY - } - UserError::NotFound(_) => actix_web::http::StatusCode::NOT_FOUND, - UserError::NotLoggedIn => actix_web::http::StatusCode::UNAUTHORIZED, - UserError::Empty => actix_web::http::StatusCode::NOT_FOUND, - UserError::LoginFailed => actix_web::http::StatusCode::NOT_FOUND, - UserError::InvalidToken => actix_web::http::StatusCode::UNAUTHORIZED, - }, - AppError::Actix(fail) => fail.as_response_error().status_code(), - _ => actix_web::http::StatusCode::INTERNAL_SERVER_ERROR, - } - } - - fn error_response(&self) -> HttpResponse { - let status_code = self.status_code(); - let response = 
HttpResponse::build(status_code).body(self.to_string()); - response - } -} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 0b2a2840aa8d6873f61df68b1bd0cfe2c24a2b25..0000000000000000000000000000000000000000 --- a/src/main.rs +++ /dev/null @@ -1,145 +0,0 @@ -mod api; -mod entity; -mod service; -mod errors; -mod web_socket; - -use std::env; -use actix_files::Files; -use actix_identity::{ CookieIdentityPolicy, IdentityService, RequestIdentity }; -use actix_session::CookieSession; -use actix_web::{ web, App, HttpServer, middleware, Error }; -use actix_web::cookie::time::Duration; -use actix_web::dev::ServiceRequest; -use actix_web::error::ErrorUnauthorized; -use actix_web_httpauth::extractors::bearer::BearerAuth; -use listenfd::ListenFd; -use minio::s3::client::Client; -use minio::s3::creds::StaticProvider; -use minio::s3::http::BaseUrl; -use sea_orm::{ Database, DatabaseConnection }; -use migration::{ Migrator, MigratorTrait }; -use crate::errors::{ AppError, UserError }; -use crate::web_socket::doc_info::upload_file_ws; - -#[derive(Debug, Clone)] -struct AppState { - conn: DatabaseConnection, - s3_client: Client, -} - -pub(crate) async fn validator( - req: ServiceRequest, - credentials: BearerAuth -) -> Result { - if let Some(token) = req.get_identity() { - println!("{}, {}", credentials.token(), token); - (credentials.token() == token) - .then(|| req) - .ok_or(ErrorUnauthorized(UserError::InvalidToken)) - } else { - Err(ErrorUnauthorized(UserError::NotLoggedIn)) - } -} - -#[actix_web::main] -async fn main() -> Result<(), AppError> { - std::env::set_var("RUST_LOG", "debug"); - tracing_subscriber::fmt::init(); - - // get env vars - dotenvy::dotenv().ok(); - let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is not set in .env file"); - let host = env::var("HOST").expect("HOST is not set in .env file"); - let port = env::var("PORT").expect("PORT is not set in .env file"); - let server_url = format!("{host}:{port}"); - - let mut s3_base_url = env::var("MINIO_HOST").expect("MINIO_HOST is not set in .env file"); - let s3_access_key = env::var("MINIO_USR").expect("MINIO_USR is not set in .env file"); - let s3_secret_key = env::var("MINIO_PWD").expect("MINIO_PWD is not set in .env file"); - if s3_base_url.find("http") != Some(0) { - s3_base_url = format!("http://{}", s3_base_url); - } - - // establish connection to database and apply migrations - // -> create post table if not exists - let conn = Database::connect(&db_url).await.unwrap(); - Migrator::up(&conn, None).await.unwrap(); - - let static_provider = StaticProvider::new(s3_access_key.as_str(), s3_secret_key.as_str(), None); - - let s3_client = Client::new( - s3_base_url.parse::()?, - Some(Box::new(static_provider)), - None, - Some(true) - )?; - - let state = AppState { conn, s3_client }; - - // create server and try to serve over socket if possible - let mut listenfd = ListenFd::from_env(); - let mut server = HttpServer::new(move || { - App::new() - .service(Files::new("/static", "./static")) - .app_data(web::Data::new(state.clone())) - .wrap( - IdentityService::new( - CookieIdentityPolicy::new(&[0; 32]) - .name("auth-cookie") - .login_deadline(Duration::seconds(120)) - .secure(false) - ) - ) - .wrap( - CookieSession::signed(&[0; 32]) - .name("session-cookie") - .secure(false) - // WARNING(alex): This uses the `time` crate, not `std::time`! 
- .expires_in_time(Duration::seconds(60)) - ) - .wrap(middleware::Logger::default()) - .configure(init) - }); - - server = match listenfd.take_tcp_listener(0)? { - Some(listener) => server.listen(listener)?, - None => server.bind(&server_url)?, - }; - - println!("Starting server at {server_url}"); - server.run().await?; - - Ok(()) -} - -fn init(cfg: &mut web::ServiceConfig) { - cfg.service(api::tag_info::create); - cfg.service(api::tag_info::delete); - cfg.service(api::tag_info::list); - - cfg.service(api::kb_info::create); - cfg.service(api::kb_info::delete); - cfg.service(api::kb_info::list); - cfg.service(api::kb_info::add_docs_to_kb); - cfg.service(api::kb_info::anti_kb_docs); - cfg.service(api::kb_info::all_relevents); - - cfg.service(api::doc_info::list); - cfg.service(api::doc_info::delete); - cfg.service(api::doc_info::mv); - cfg.service(api::doc_info::upload); - cfg.service(api::doc_info::new_folder); - cfg.service(api::doc_info::rename); - - cfg.service(api::dialog_info::list); - cfg.service(api::dialog_info::delete); - cfg.service(api::dialog_info::create); - cfg.service(api::dialog_info::update_history); - - cfg.service(api::user_info::login); - cfg.service(api::user_info::register); - cfg.service(api::user_info::setting); - - cfg.service(web::resource("/ws-upload-doc").route(web::get().to(upload_file_ws))); -} diff --git a/src/service/dialog_info.rs b/src/service/dialog_info.rs deleted file mode 100644 index a5a393021ff1d79cc029a25cf9479ecee3b89cfe..0000000000000000000000000000000000000000 --- a/src/service/dialog_info.rs +++ /dev/null @@ -1,107 +0,0 @@ -use chrono::{ Local, FixedOffset, Utc }; -use migration::Expr; -use sea_orm::{ - ActiveModelTrait, - DbConn, - DbErr, - DeleteResult, - EntityTrait, - PaginatorTrait, - QueryOrder, - UpdateResult, -}; -use sea_orm::ActiveValue::Set; -use sea_orm::QueryFilter; -use sea_orm::ColumnTrait; -use crate::entity::dialog_info; -use crate::entity::dialog_info::Entity; - -fn now() -> chrono::DateTime { - Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) -} -pub struct Query; - -impl Query { - pub async fn find_dialog_info_by_id( - db: &DbConn, - id: i64 - ) -> Result, DbErr> { - Entity::find_by_id(id).one(db).await - } - - pub async fn find_dialog_infos(db: &DbConn) -> Result, DbErr> { - Entity::find().all(db).await - } - - pub async fn find_dialog_infos_by_uid( - db: &DbConn, - uid: i64 - ) -> Result, DbErr> { - Entity::find() - .filter(dialog_info::Column::Uid.eq(uid)) - .filter(dialog_info::Column::IsDeleted.eq(false)) - .all(db).await - } - - pub async fn find_dialog_infos_in_page( - db: &DbConn, - page: u64, - posts_per_page: u64 - ) -> Result<(Vec, u64), DbErr> { - // Setup paginator - let paginator = Entity::find() - .order_by_asc(dialog_info::Column::DialogId) - .paginate(db, posts_per_page); - let num_pages = paginator.num_pages().await?; - - // Fetch paginated posts - paginator.fetch_page(page - 1).await.map(|p| (p, num_pages)) - } -} - -pub struct Mutation; - -impl Mutation { - pub async fn create_dialog_info( - db: &DbConn, - uid: i64, - kb_id: i64, - name: &String - ) -> Result { - (dialog_info::ActiveModel { - dialog_id: Default::default(), - uid: Set(uid), - kb_id: Set(kb_id), - dialog_name: Set(name.to_owned()), - history: Set("".to_owned()), - created_at: Set(now()), - updated_at: Set(now()), - is_deleted: Default::default(), - }).save(db).await - } - - pub async fn update_dialog_info_by_id( - db: &DbConn, - dialog_id: i64, - dialog_name: &String, - history: &String - ) -> Result { - 
Entity::update_many() - .col_expr(dialog_info::Column::DialogName, Expr::value(dialog_name)) - .col_expr(dialog_info::Column::History, Expr::value(history)) - .col_expr(dialog_info::Column::UpdatedAt, Expr::value(now())) - .filter(dialog_info::Column::DialogId.eq(dialog_id)) - .exec(db).await - } - - pub async fn delete_dialog_info(db: &DbConn, dialog_id: i64) -> Result { - Entity::update_many() - .col_expr(dialog_info::Column::IsDeleted, Expr::value(true)) - .filter(dialog_info::Column::DialogId.eq(dialog_id)) - .exec(db).await - } - - pub async fn delete_all_dialog_infos(db: &DbConn) -> Result { - Entity::delete_many().exec(db).await - } -} diff --git a/src/service/doc_info.rs b/src/service/doc_info.rs deleted file mode 100644 index 18fc87f5c640dff2b263fbee5b55205ae8fa8717..0000000000000000000000000000000000000000 --- a/src/service/doc_info.rs +++ /dev/null @@ -1,335 +0,0 @@ -use chrono::{ Utc, FixedOffset }; -use sea_orm::{ - ActiveModelTrait, - ColumnTrait, - DbConn, - DbErr, - DeleteResult, - EntityTrait, - PaginatorTrait, - QueryOrder, - Unset, - Unchanged, - ConditionalStatement, - QuerySelect, - JoinType, - RelationTrait, - DbBackend, - Statement, - UpdateResult, -}; -use sea_orm::ActiveValue::Set; -use sea_orm::QueryFilter; -use crate::api::doc_info::ListParams; -use crate::entity::{ doc2_doc, doc_info }; -use crate::entity::doc_info::Entity; -use crate::service; - -fn now() -> chrono::DateTime { - Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) -} - -pub struct Query; - -impl Query { - pub async fn find_doc_info_by_id( - db: &DbConn, - id: i64 - ) -> Result, DbErr> { - Entity::find_by_id(id).one(db).await - } - - pub async fn find_doc_infos(db: &DbConn) -> Result, DbErr> { - Entity::find().all(db).await - } - - pub async fn find_doc_infos_by_uid( - db: &DbConn, - uid: i64 - ) -> Result, DbErr> { - Entity::find().filter(doc_info::Column::Uid.eq(uid)).all(db).await - } - - pub async fn find_doc_infos_by_name( - db: &DbConn, - uid: i64, - name: &String, - parent_id: Option - ) -> Result, DbErr> { - let mut dids = Vec::::new(); - if let Some(pid) = parent_id { - for d2d in doc2_doc::Entity - ::find() - .filter(doc2_doc::Column::ParentId.eq(pid)) - .all(db).await? { - dids.push(d2d.did); - } - } else { - let doc = Entity::find() - .filter(doc_info::Column::DocName.eq(name.clone())) - .filter(doc_info::Column::Uid.eq(uid)) - .all(db).await?; - if doc.len() == 0 { - return Ok(vec![]); - } - assert!(doc.len() > 0); - let d2d = doc2_doc::Entity - ::find() - .filter(doc2_doc::Column::Did.eq(doc[0].did)) - .all(db).await?; - assert!(d2d.len() <= 1, "Did: {}->{}", doc[0].did, d2d.len()); - if d2d.len() > 0 { - for d2d_ in doc2_doc::Entity - ::find() - .filter(doc2_doc::Column::ParentId.eq(d2d[0].parent_id)) - .all(db).await? { - dids.push(d2d_.did); - } - } - } - - Entity::find() - .filter(doc_info::Column::DocName.eq(name.clone())) - .filter(doc_info::Column::Uid.eq(uid)) - .filter(doc_info::Column::Did.is_in(dids)) - .filter(doc_info::Column::IsDeleted.eq(false)) - .all(db).await - } - - pub async fn all_descendent_ids(db: &DbConn, doc_ids: &Vec) -> Result, DbErr> { - let mut dids = doc_ids.clone(); - let mut i: usize = 0; - loop { - if dids.len() == i { - break; - } - - for d in doc2_doc::Entity - ::find() - .filter(doc2_doc::Column::ParentId.eq(dids[i])) - .all(db).await? 
{ - dids.push(d.did); - } - i += 1; - } - Ok(dids) - } - - pub async fn find_doc_infos_by_params( - db: &DbConn, - params: ListParams - ) -> Result, DbErr> { - // Setup paginator - let mut sql: String = - " - select - a.did, - a.uid, - a.doc_name, - a.location, - a.size, - a.type, - a.created_at, - a.updated_at, - a.is_deleted - from - doc_info as a - ".to_owned(); - - let mut cond: String = format!(" a.uid={} and a.is_deleted=False ", params.uid); - - if let Some(kb_id) = params.filter.kb_id { - sql.push_str( - &format!(" inner join kb2_doc on kb2_doc.did = a.did and kb2_doc.kb_id={}", kb_id) - ); - } - if let Some(folder_id) = params.filter.folder_id { - sql.push_str( - &format!(" inner join doc2_doc on a.did = doc2_doc.did and doc2_doc.parent_id={}", folder_id) - ); - } - // Fetch paginated posts - if let Some(tag_id) = params.filter.tag_id { - let tag = service::tag_info::Query - ::find_tag_info_by_id(tag_id, &db).await - .unwrap() - .unwrap(); - if tag.folder_id > 0 { - sql.push_str( - &format!( - " inner join doc2_doc on a.did = doc2_doc.did and doc2_doc.parent_id={}", - tag.folder_id - ) - ); - } - if tag.regx.len() > 0 { - cond.push_str(&format!(" and (type='{}' or doc_name ~ '{}') ", tag.tag_name, tag.regx)); - } - } - - if let Some(keywords) = params.filter.keywords { - cond.push_str(&format!(" and doc_name like '%{}%'", keywords)); - } - if cond.len() > 0 { - sql.push_str(&" where "); - sql.push_str(&cond); - } - let mut orderby = params.sortby.clone(); - if orderby.len() == 0 { - orderby = "updated_at desc".to_owned(); - } - sql.push_str(&format!(" order by {}", orderby)); - let mut page_size: u32 = 30; - if let Some(pg_sz) = params.per_page { - page_size = pg_sz; - } - let mut page: u32 = 0; - if let Some(pg) = params.page { - page = pg; - } - sql.push_str(&format!(" limit {} offset {} ;", page_size, page * page_size)); - - print!("{}", sql); - Entity::find() - .from_raw_sql(Statement::from_sql_and_values(DbBackend::Postgres, sql, vec![])) - .all(db).await - } - - pub async fn find_doc_infos_in_page( - db: &DbConn, - page: u64, - posts_per_page: u64 - ) -> Result<(Vec, u64), DbErr> { - // Setup paginator - let paginator = Entity::find() - .order_by_asc(doc_info::Column::Did) - .paginate(db, posts_per_page); - let num_pages = paginator.num_pages().await?; - - // Fetch paginated posts - paginator.fetch_page(page - 1).await.map(|p| (p, num_pages)) - } -} - -pub struct Mutation; - -impl Mutation { - pub async fn mv_doc_info(db: &DbConn, dest_did: i64, dids: &[i64]) -> Result<(), DbErr> { - for did in dids { - let d = doc2_doc::Entity - ::find() - .filter(doc2_doc::Column::Did.eq(did.to_owned())) - .all(db).await?; - - let _ = (doc2_doc::ActiveModel { - id: Set(d[0].id), - did: Set(did.to_owned()), - parent_id: Set(dest_did), - }).update(db).await?; - } - - Ok(()) - } - - pub async fn place_doc( - db: &DbConn, - dest_did: i64, - did: i64 - ) -> Result { - (doc2_doc::ActiveModel { - id: Default::default(), - parent_id: Set(dest_did), - did: Set(did), - }).save(db).await - } - - pub async fn create_doc_info( - db: &DbConn, - form_data: doc_info::Model - ) -> Result { - (doc_info::ActiveModel { - did: Default::default(), - uid: Set(form_data.uid.to_owned()), - doc_name: Set(form_data.doc_name.to_owned()), - size: Set(form_data.size.to_owned()), - r#type: Set(form_data.r#type.to_owned()), - location: Set(form_data.location.to_owned()), - thumbnail_base64: Default::default(), - created_at: Set(form_data.created_at.to_owned()), - updated_at: Set(form_data.updated_at.to_owned()), - 
is_deleted: Default::default(), - }).save(db).await - } - - pub async fn update_doc_info_by_id( - db: &DbConn, - id: i64, - form_data: doc_info::Model - ) -> Result { - let doc_info: doc_info::ActiveModel = Entity::find_by_id(id) - .one(db).await? - .ok_or(DbErr::Custom("Cannot find.".to_owned())) - .map(Into::into)?; - - (doc_info::ActiveModel { - did: doc_info.did, - uid: Set(form_data.uid.to_owned()), - doc_name: Set(form_data.doc_name.to_owned()), - size: Set(form_data.size.to_owned()), - r#type: Set(form_data.r#type.to_owned()), - location: Set(form_data.location.to_owned()), - thumbnail_base64: doc_info.thumbnail_base64, - created_at: doc_info.created_at, - updated_at: Set(now()), - is_deleted: Default::default(), - }).update(db).await - } - - pub async fn delete_doc_info(db: &DbConn, doc_ids: &Vec) -> Result { - let mut dids = doc_ids.clone(); - let mut i: usize = 0; - loop { - if dids.len() == i { - break; - } - let mut doc: doc_info::ActiveModel = Entity::find_by_id(dids[i]) - .one(db).await? - .ok_or(DbErr::Custom(format!("Can't find doc:{}", dids[i]))) - .map(Into::into)?; - doc.updated_at = Set(now()); - doc.is_deleted = Set(true); - let _ = doc.update(db).await?; - - for d in doc2_doc::Entity - ::find() - .filter(doc2_doc::Column::ParentId.eq(dids[i])) - .all(db).await? { - dids.push(d.did); - } - let _ = doc2_doc::Entity - ::delete_many() - .filter(doc2_doc::Column::ParentId.eq(dids[i])) - .exec(db).await?; - let _ = doc2_doc::Entity - ::delete_many() - .filter(doc2_doc::Column::Did.eq(dids[i])) - .exec(db).await?; - i += 1; - } - crate::service::kb_info::Mutation::remove_docs(&db, dids, None).await - } - - pub async fn rename(db: &DbConn, doc_id: i64, name: &String) -> Result { - let mut doc: doc_info::ActiveModel = Entity::find_by_id(doc_id) - .one(db).await? 
- .ok_or(DbErr::Custom(format!("Can't find doc:{}", doc_id))) - .map(Into::into)?; - doc.updated_at = Set(now()); - doc.doc_name = Set(name.clone()); - doc.update(db).await - } - - pub async fn delete_all_doc_infos(db: &DbConn) -> Result { - Entity::delete_many().exec(db).await - } -} diff --git a/src/service/kb_info.rs b/src/service/kb_info.rs deleted file mode 100644 index fe25ecacf9571e450652182659cf53c44d499a4b..0000000000000000000000000000000000000000 --- a/src/service/kb_info.rs +++ /dev/null @@ -1,168 +0,0 @@ -use chrono::{ Local, FixedOffset, Utc }; -use migration::Expr; -use sea_orm::{ - ActiveModelTrait, - ColumnTrait, - DbConn, - DbErr, - DeleteResult, - EntityTrait, - PaginatorTrait, - QueryFilter, - QueryOrder, - UpdateResult, -}; -use sea_orm::ActiveValue::Set; -use crate::entity::kb_info; -use crate::entity::kb2_doc; -use crate::entity::kb_info::Entity; - -fn now() -> chrono::DateTime { - Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) -} -pub struct Query; - -impl Query { - pub async fn find_kb_info_by_id(db: &DbConn, id: i64) -> Result, DbErr> { - Entity::find_by_id(id).one(db).await - } - - pub async fn find_kb_infos(db: &DbConn) -> Result, DbErr> { - Entity::find().all(db).await - } - - pub async fn find_kb_infos_by_uid(db: &DbConn, uid: i64) -> Result, DbErr> { - Entity::find().filter(kb_info::Column::Uid.eq(uid)).all(db).await - } - - pub async fn find_kb_infos_by_name( - db: &DbConn, - name: String - ) -> Result, DbErr> { - Entity::find().filter(kb_info::Column::KbName.eq(name)).all(db).await - } - - pub async fn find_kb_by_docs( - db: &DbConn, - doc_ids: Vec - ) -> Result, DbErr> { - let mut kbids = Vec::::new(); - for k in kb2_doc::Entity - ::find() - .filter(kb2_doc::Column::Did.is_in(doc_ids)) - .all(db).await? 
{ - kbids.push(k.kb_id); - } - Entity::find().filter(kb_info::Column::KbId.is_in(kbids)).all(db).await - } - - pub async fn find_kb_infos_in_page( - db: &DbConn, - page: u64, - posts_per_page: u64 - ) -> Result<(Vec, u64), DbErr> { - // Setup paginator - let paginator = Entity::find() - .order_by_asc(kb_info::Column::KbId) - .paginate(db, posts_per_page); - let num_pages = paginator.num_pages().await?; - - // Fetch paginated posts - paginator.fetch_page(page - 1).await.map(|p| (p, num_pages)) - } -} - -pub struct Mutation; - -impl Mutation { - pub async fn create_kb_info( - db: &DbConn, - form_data: kb_info::Model - ) -> Result { - (kb_info::ActiveModel { - kb_id: Default::default(), - uid: Set(form_data.uid.to_owned()), - kb_name: Set(form_data.kb_name.to_owned()), - icon: Set(form_data.icon.to_owned()), - created_at: Set(now()), - updated_at: Set(now()), - is_deleted: Default::default(), - }).save(db).await - } - - pub async fn add_docs(db: &DbConn, kb_id: i64, doc_ids: Vec) -> Result<(), DbErr> { - for did in doc_ids { - let res = kb2_doc::Entity - ::find() - .filter(kb2_doc::Column::KbId.eq(kb_id)) - .filter(kb2_doc::Column::Did.eq(did)) - .all(db).await?; - if res.len() > 0 { - continue; - } - let _ = (kb2_doc::ActiveModel { - id: Default::default(), - kb_id: Set(kb_id), - did: Set(did), - kb_progress: Set(0.0), - kb_progress_msg: Set("".to_owned()), - updated_at: Set(now()), - is_deleted: Default::default(), - }).save(db).await?; - } - - Ok(()) - } - - pub async fn remove_docs( - db: &DbConn, - doc_ids: Vec, - kb_id: Option - ) -> Result { - let update = kb2_doc::Entity - ::update_many() - .col_expr(kb2_doc::Column::IsDeleted, Expr::value(true)) - .col_expr(kb2_doc::Column::KbProgress, Expr::value(0)) - .col_expr(kb2_doc::Column::KbProgressMsg, Expr::value("")) - .filter(kb2_doc::Column::Did.is_in(doc_ids)); - if let Some(kbid) = kb_id { - update.filter(kb2_doc::Column::KbId.eq(kbid)).exec(db).await - } else { - update.exec(db).await - } - } - - pub async fn update_kb_info_by_id( - db: &DbConn, - id: i64, - form_data: kb_info::Model - ) -> Result { - let kb_info: kb_info::ActiveModel = Entity::find_by_id(id) - .one(db).await? - .ok_or(DbErr::Custom("Cannot find.".to_owned())) - .map(Into::into)?; - - (kb_info::ActiveModel { - kb_id: kb_info.kb_id, - uid: kb_info.uid, - kb_name: Set(form_data.kb_name.to_owned()), - icon: Set(form_data.icon.to_owned()), - created_at: kb_info.created_at, - updated_at: Set(now()), - is_deleted: Default::default(), - }).update(db).await - } - - pub async fn delete_kb_info(db: &DbConn, kb_id: i64) -> Result { - let kb: kb_info::ActiveModel = Entity::find_by_id(kb_id) - .one(db).await? 
- .ok_or(DbErr::Custom("Cannot find.".to_owned())) - .map(Into::into)?; - - kb.delete(db).await - } - - pub async fn delete_all_kb_infos(db: &DbConn) -> Result { - Entity::delete_many().exec(db).await - } -} diff --git a/src/service/mod.rs b/src/service/mod.rs deleted file mode 100644 index 41371e396f64eb618cb8f89a4f97393316d107d9..0000000000000000000000000000000000000000 --- a/src/service/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub(crate) mod dialog_info; -pub(crate) mod tag_info; -pub(crate) mod kb_info; -pub(crate) mod doc_info; -pub(crate) mod user_info; diff --git a/src/service/tag_info.rs b/src/service/tag_info.rs deleted file mode 100644 index 2fec87a3859ff303ebe43342de1f6fe30e306abf..0000000000000000000000000000000000000000 --- a/src/service/tag_info.rs +++ /dev/null @@ -1,108 +0,0 @@ -use chrono::{ FixedOffset, Utc }; -use sea_orm::{ - ActiveModelTrait, - DbConn, - DbErr, - DeleteResult, - EntityTrait, - PaginatorTrait, - QueryOrder, - ColumnTrait, - QueryFilter, -}; -use sea_orm::ActiveValue::{ Set, NotSet }; -use crate::entity::tag_info; -use crate::entity::tag_info::Entity; - -fn now() -> chrono::DateTime { - Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) -} -pub struct Query; - -impl Query { - pub async fn find_tag_info_by_id( - id: i64, - db: &DbConn - ) -> Result, DbErr> { - Entity::find_by_id(id).one(db).await - } - - pub async fn find_tags_by_uid(uid: i64, db: &DbConn) -> Result, DbErr> { - Entity::find().filter(tag_info::Column::Uid.eq(uid)).all(db).await - } - - pub async fn find_tag_infos_in_page( - db: &DbConn, - page: u64, - posts_per_page: u64 - ) -> Result<(Vec, u64), DbErr> { - // Setup paginator - let paginator = Entity::find() - .order_by_asc(tag_info::Column::Tid) - .paginate(db, posts_per_page); - let num_pages = paginator.num_pages().await?; - - // Fetch paginated posts - paginator.fetch_page(page - 1).await.map(|p| (p, num_pages)) - } -} - -pub struct Mutation; - -impl Mutation { - pub async fn create_tag( - db: &DbConn, - form_data: tag_info::Model - ) -> Result { - (tag_info::ActiveModel { - tid: Default::default(), - uid: Set(form_data.uid.to_owned()), - tag_name: Set(form_data.tag_name.to_owned()), - regx: Set(form_data.regx.to_owned()), - color: Set(form_data.color.to_owned()), - icon: Set(form_data.icon.to_owned()), - folder_id: match form_data.folder_id { - 0 => NotSet, - _ => Set(form_data.folder_id.to_owned()), - }, - created_at: Set(now()), - updated_at: Set(now()), - }).save(db).await - } - - pub async fn update_tag_by_id( - db: &DbConn, - id: i64, - form_data: tag_info::Model - ) -> Result { - let tag: tag_info::ActiveModel = Entity::find_by_id(id) - .one(db).await? - .ok_or(DbErr::Custom("Cannot find tag.".to_owned())) - .map(Into::into)?; - - (tag_info::ActiveModel { - tid: tag.tid, - uid: tag.uid, - tag_name: Set(form_data.tag_name.to_owned()), - regx: Set(form_data.regx.to_owned()), - color: Set(form_data.color.to_owned()), - icon: Set(form_data.icon.to_owned()), - folder_id: Set(form_data.folder_id.to_owned()), - created_at: Default::default(), - updated_at: Set(now()), - }).update(db).await - } - - pub async fn delete_tag(db: &DbConn, tid: i64) -> Result { - let tag: tag_info::ActiveModel = Entity::find_by_id(tid) - .one(db).await? 
- .ok_or(DbErr::Custom("Cannot find tag.".to_owned())) - .map(Into::into)?; - - tag.delete(db).await - } - - pub async fn delete_all_tags(db: &DbConn) -> Result { - Entity::delete_many().exec(db).await - } -} diff --git a/src/service/user_info.rs b/src/service/user_info.rs deleted file mode 100644 index 5e15611c65f9c332b13345ae064825f158de3714..0000000000000000000000000000000000000000 --- a/src/service/user_info.rs +++ /dev/null @@ -1,131 +0,0 @@ -use chrono::{ FixedOffset, Utc }; -use migration::Expr; -use sea_orm::{ - ActiveModelTrait, - ColumnTrait, - DbConn, - DbErr, - DeleteResult, - EntityTrait, - PaginatorTrait, - QueryFilter, - QueryOrder, - UpdateResult, -}; -use sea_orm::ActiveValue::Set; -use crate::entity::user_info; -use crate::entity::user_info::Entity; - -fn now() -> chrono::DateTime { - Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap()) -} -pub struct Query; - -impl Query { - pub async fn find_user_info_by_id( - db: &DbConn, - id: i64 - ) -> Result, DbErr> { - Entity::find_by_id(id).one(db).await - } - - pub async fn login( - db: &DbConn, - email: &str, - password: &str - ) -> Result, DbErr> { - Entity::find() - .filter(user_info::Column::Email.eq(email)) - .filter(user_info::Column::Password.eq(password)) - .one(db).await - } - - pub async fn find_user_infos( - db: &DbConn, - email: &String - ) -> Result, DbErr> { - Entity::find().filter(user_info::Column::Email.eq(email)).one(db).await - } - - pub async fn find_user_infos_in_page( - db: &DbConn, - page: u64, - posts_per_page: u64 - ) -> Result<(Vec, u64), DbErr> { - // Setup paginator - let paginator = Entity::find() - .order_by_asc(user_info::Column::Uid) - .paginate(db, posts_per_page); - let num_pages = paginator.num_pages().await?; - - // Fetch paginated posts - paginator.fetch_page(page - 1).await.map(|p| (p, num_pages)) - } -} - -pub struct Mutation; - -impl Mutation { - pub async fn create_user( - db: &DbConn, - form_data: &user_info::Model - ) -> Result { - (user_info::ActiveModel { - uid: Default::default(), - email: Set(form_data.email.to_owned()), - nickname: Set(form_data.nickname.to_owned()), - avatar_base64: Set(form_data.avatar_base64.to_owned()), - color_scheme: Set(form_data.color_scheme.to_owned()), - list_style: Set(form_data.list_style.to_owned()), - language: Set(form_data.language.to_owned()), - password: Set(form_data.password.to_owned()), - last_login_at: Set(now()), - created_at: Set(now()), - updated_at: Set(now()), - }).save(db).await - } - - pub async fn update_user_by_id( - db: &DbConn, - form_data: &user_info::Model - ) -> Result { - let usr: user_info::ActiveModel = Entity::find_by_id(form_data.uid) - .one(db).await? 
- .ok_or(DbErr::Custom("Cannot find user.".to_owned())) - .map(Into::into)?; - - (user_info::ActiveModel { - uid: Set(form_data.uid), - email: Set(form_data.email.to_owned()), - nickname: Set(form_data.nickname.to_owned()), - avatar_base64: Set(form_data.avatar_base64.to_owned()), - color_scheme: Set(form_data.color_scheme.to_owned()), - list_style: Set(form_data.list_style.to_owned()), - language: Set(form_data.language.to_owned()), - password: Set(form_data.password.to_owned()), - updated_at: Set(now()), - last_login_at: usr.last_login_at, - created_at: usr.created_at, - }).update(db).await - } - - pub async fn update_login_status(uid: i64, db: &DbConn) -> Result { - Entity::update_many() - .col_expr(user_info::Column::LastLoginAt, Expr::value(now())) - .filter(user_info::Column::Uid.eq(uid)) - .exec(db).await - } - - pub async fn delete_user(db: &DbConn, tid: i64) -> Result { - let tag: user_info::ActiveModel = Entity::find_by_id(tid) - .one(db).await? - .ok_or(DbErr::Custom("Cannot find tag.".to_owned())) - .map(Into::into)?; - - tag.delete(db).await - } - - pub async fn delete_all(db: &DbConn) -> Result { - Entity::delete_many().exec(db).await - } -} diff --git a/src/web_socket/doc_info.rs b/src/web_socket/doc_info.rs deleted file mode 100644 index 70b8022caac9e39cd78a2970bf93b7d0fe1b597b..0000000000000000000000000000000000000000 --- a/src/web_socket/doc_info.rs +++ /dev/null @@ -1,97 +0,0 @@ -use std::io::{Cursor, Write}; -use std::time::{Duration, Instant}; -use actix_rt::time::interval; -use actix_web::{HttpRequest, HttpResponse, rt, web}; -use actix_web::web::Buf; -use actix_ws::Message; -use futures_util::{future, StreamExt}; -use futures_util::future::Either; -use uuid::Uuid; -use crate::api::doc_info::_upload_file; -use crate::AppState; -use crate::errors::AppError; - -const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); - -/// How long before lack of client response causes a timeout. 
-const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); - -pub async fn upload_file_ws(req: HttpRequest, stream: web::Payload, data: web::Data) -> Result { - let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; - - // spawn websocket handler (and don't await it) so that the response is returned immediately - rt::spawn(upload_file_handler(data, session, msg_stream)); - - Ok(res) -} - -async fn upload_file_handler( - data: web::Data, - mut session: actix_ws::Session, - mut msg_stream: actix_ws::MessageStream, -) { - let mut bytes = Cursor::new(vec![]); - let mut last_heartbeat = Instant::now(); - let mut interval = interval(HEARTBEAT_INTERVAL); - - let reason = loop { - let tick = interval.tick(); - tokio::pin!(tick); - - match future::select(msg_stream.next(), tick).await { - // received message from WebSocket client - Either::Left((Some(Ok(msg)), _)) => { - match msg { - Message::Text(text) => { - session.text(text).await.unwrap(); - } - - Message::Binary(bin) => { - let mut pos = 0; // notice the name of the file that will be written - while pos < bin.len() { - let bytes_written = bytes.write(&bin[pos..]).unwrap(); - pos += bytes_written - }; - session.binary(bin).await.unwrap(); - } - - Message::Close(reason) => { - break reason; - } - - Message::Ping(bytes) => { - last_heartbeat = Instant::now(); - let _ = session.pong(&bytes).await; - } - - Message::Pong(_) => { - last_heartbeat = Instant::now(); - } - - Message::Continuation(_) | Message::Nop => {} - }; - } - Either::Left((Some(Err(_)), _)) => { - break None; - } - Either::Left((None, _)) => break None, - Either::Right((_inst, _)) => { - if Instant::now().duration_since(last_heartbeat) > CLIENT_TIMEOUT { - break None; - } - - let _ = session.ping(b"").await; - } - } - }; - let _ = session.close(reason).await; - - if !bytes.has_remaining() { - return; - } - - let uid = bytes.get_i64(); - let did = bytes.get_i64(); - - _upload_file(uid, did, &Uuid::new_v4().to_string(), &bytes.into_inner(), &data).await.unwrap(); -} \ No newline at end of file diff --git a/src/web_socket/mod.rs b/src/web_socket/mod.rs deleted file mode 100644 index 2753903b57eb9e3ae88ae2d9693bc36889bb3e18..0000000000000000000000000000000000000000 --- a/src/web_socket/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod doc_info; \ No newline at end of file diff --git a/web_server/__init__.py b/web_server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web_server/apps/__init__.py b/web_server/apps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a7cc9d30c13e041dad9f1291b9de8372940214e --- /dev/null +++ b/web_server/apps/__init__.py @@ -0,0 +1,147 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
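The application factory that follows registers every *_app.py module under web_server/apps as a Flask blueprint: each page file is loaded with importlib, handed app and a freshly created manager = Blueprint(...) attribute before its body executes, and then mounted under /{API_VERSION}/{page_name}. That injection is why the page modules later in this change (document_app.py, kb_app.py, user_app.py) attach their routes to a manager object they never import. A minimal sketch of the same loading pattern, with an illustrative file layout and placeholder names rather than the project's exact code:

import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

from flask import Blueprint, Flask

app = Flask(__name__)
API_VERSION = "v1"  # placeholder; the real prefix comes from web_server.settings


def register_page(page_path: Path) -> str:
    # Load one *_app.py file and mount its blueprint under /{API_VERSION}/{page_name}.
    page_name = page_path.stem[: -len("_app")]         # "document_app" -> "document"
    spec = spec_from_file_location(page_name, page_path)
    page = module_from_spec(spec)
    page.app = app                                      # injected before the module body runs
    page.manager = Blueprint(page_name, page_name)      # the manager the page's routes decorate
    sys.modules[page_name] = page
    spec.loader.exec_module(page)
    url_prefix = f"/{API_VERSION}/{page_name}"
    app.register_blueprint(page.manager, url_prefix=url_prefix)
    return url_prefix


# Register every page module that sits next to this sketch.
for path in sorted(Path(__file__).parent.glob("*_app.py")):
    register_page(path)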
+# +import logging +import sys +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path +from flask import Blueprint, Flask, request +from werkzeug.wrappers.request import Request +from flask_cors import CORS + +from web_server.db import StatusEnum +from web_server.db.services import UserService +from web_server.utils import CustomJSONEncoder + +from flask_session import Session +from flask_login import LoginManager +from web_server.settings import RetCode, SECRET_KEY, stat_logger +from web_server.hook import HookManager +from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters +from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger +from web_server.utils.api_utils import get_json_result, server_error_response +from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer + +__all__ = ['app'] + + +logger = logging.getLogger('flask.app') +for h in access_logger.handlers: + logger.addHandler(h) + +Request.json = property(lambda self: self.get_json(force=True, silent=True)) + +app = Flask(__name__) +CORS(app, supports_credentials=True,max_age = 2592000) +app.url_map.strict_slashes = False +app.json_encoder = CustomJSONEncoder +app.errorhandler(Exception)(server_error_response) + + +## convince for dev and debug +#app.config["LOGIN_DISABLED"] = True +app.config["SESSION_PERMANENT"] = False +app.config["SESSION_TYPE"] = "filesystem" +app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024 + +Session(app) +login_manager = LoginManager() +login_manager.init_app(app) + + + +def search_pages_path(pages_dir): + return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')] + + +def register_page(page_path): + page_name = page_path.stem.rstrip('_app') + module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, )) + + spec = spec_from_file_location(module_name, page_path) + page = module_from_spec(spec) + page.app = app + page.manager = Blueprint(page_name, module_name) + sys.modules[module_name] = page + spec.loader.exec_module(page) + + page_name = getattr(page, 'page_name', page_name) + url_prefix = f'/{API_VERSION}/{page_name}' + + app.register_blueprint(page.manager, url_prefix=url_prefix) + return url_prefix + + +pages_dir = [ + Path(__file__).parent, + Path(__file__).parent.parent / 'web_server' / 'apps', +] + +client_urls_prefix = [ + register_page(path) + for dir in pages_dir + for path in search_pages_path(dir) +] + + +def client_authentication_before_request(): + result = HookManager.client_authentication(ClientAuthenticationParameters( + request.full_path, request.headers, + request.form, request.data, request.json, + )) + + if result.code != RetCode.SUCCESS: + return get_json_result(result.code, result.message) + + +def site_authentication_before_request(): + for url_prefix in client_urls_prefix: + if request.path.startswith(url_prefix): + return + + result = HookManager.site_authentication(AuthenticationParameters( + request.headers.get('site_signature'), + request.json, + )) + + if result.code != RetCode.SUCCESS: + return get_json_result(result.code, result.message) + + +@app.before_request +def authentication_before_request(): + if CLIENT_AUTHENTICATION: + return client_authentication_before_request() + + if SITE_AUTHENTICATION: + return site_authentication_before_request() + +@login_manager.request_loader +def load_user(web_request): + jwt = Serializer(secret_key=SECRET_KEY) + authorization = 
web_request.headers.get("Authorization") + if authorization: + try: + access_token = str(jwt.loads(authorization)) + user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) + if user: + return user[0] + else: + return None + except Exception as e: + stat_logger.exception(e) + return None + else: + return None \ No newline at end of file diff --git a/web_server/apps/document_app.py b/web_server/apps/document_app.py new file mode 100644 index 0000000000000000000000000000000000000000..d14d69ab1d67c820c841dd8388b30950bb89622f --- /dev/null +++ b/web_server/apps/document_app.py @@ -0,0 +1,235 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pathlib + +from elasticsearch_dsl import Q +from flask import request +from flask_login import login_required, current_user + +from rag.nlp import search +from rag.utils import ELASTICSEARCH +from web_server.db.services import duplicate_name +from web_server.db.services.kb_service import KnowledgebaseService +from web_server.db.services.user_service import TenantService +from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request +from web_server.utils import get_uuid, get_format_time +from web_server.db import StatusEnum, FileType +from web_server.db.services.document_service import DocumentService +from web_server.settings import RetCode +from web_server.utils.api_utils import get_json_result +from rag.utils.minio_conn import MINIO +from web_server.utils.file_utils import filename_type + + +@manager.route('/upload', methods=['POST']) +@login_required +@validate_request("kb_id") +def upload(): + kb_id = request.form.get("kb_id") + if not kb_id: + return get_json_result( + data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) + if 'file' not in request.files: + return get_json_result( + data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR) + file = request.files['file'] + if file.filename == '': + return get_json_result( + data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) + + try: + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + return get_data_error_result( + retmsg="Can't find this knowledgebase!") + + filename = duplicate_name( + DocumentService.query, + name=file.filename, + kb_id=kb.id) + location = filename + while MINIO.obj_exist(kb_id, location): + location += "_" + blob = request.files['file'].read() + MINIO.put(kb_id, filename, blob) + doc = DocumentService.insert({ + "id": get_uuid(), + "kb_id": kb.id, + "parser_id": kb.parser_id, + "created_by": current_user.id, + "type": filename_type(filename), + "name": filename, + "location": location, + "size": len(blob) + }) + return get_json_result(data=doc.to_json()) + except Exception as e: + return server_error_response(e) + + +@manager.route('/create', methods=['POST']) +@login_required +@validate_request("name", "kb_id") +def create(): + req = request.json + kb_id = req["kb_id"] + if not kb_id: + 
return get_json_result( + data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) + + try: + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + return get_data_error_result( + retmsg="Can't find this knowledgebase!") + + if DocumentService.query(name=req["name"], kb_id=kb_id): + return get_data_error_result( + retmsg="Duplicated document name in the same knowledgebase.") + + doc = DocumentService.insert({ + "id": get_uuid(), + "kb_id": kb.id, + "parser_id": kb.parser_id, + "created_by": current_user.id, + "type": FileType.VIRTUAL, + "name": req["name"], + "location": "", + "size": 0 + }) + return get_json_result(data=doc.to_json()) + except Exception as e: + return server_error_response(e) + + +@manager.route('/list', methods=['GET']) +@login_required +def list(): + kb_id = request.args.get("kb_id") + if not kb_id: + return get_json_result( + data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) + keywords = request.args.get("keywords", "") + + page_number = request.args.get("page", 1) + items_per_page = request.args.get("page_size", 15) + orderby = request.args.get("orderby", "create_time") + desc = request.args.get("desc", True) + try: + docs = DocumentService.get_by_kb_id( + kb_id, page_number, items_per_page, orderby, desc, keywords) + return get_json_result(data=docs) + except Exception as e: + return server_error_response(e) + + +@manager.route('/change_status', methods=['POST']) +@login_required +@validate_request("doc_id", "status") +def change_status(): + req = request.json + if str(req["status"]) not in ["0", "1"]: + get_json_result( + data=False, + retmsg='"Status" must be either 0 or 1!', + retcode=RetCode.ARGUMENT_ERROR) + + try: + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(retmsg="Document not found!") + e, kb = KnowledgebaseService.get_by_id(doc.kb_id) + if not e: + return get_data_error_result( + retmsg="Can't find this knowledgebase!") + + if not DocumentService.update_by_id( + req["doc_id"], {"status": str(req["status"])}): + return get_data_error_result( + retmsg="Database error (Document update)!") + + if str(req["status"]) == "0": + ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), + scripts=""" + if(ctx._source.kb_id.contains('%s')) + ctx._source.kb_id.remove( + ctx._source.kb_id.indexOf('%s') + ); + """ % (doc.kb_id, doc.kb_id), + idxnm=search.index_name( + kb.tenant_id) + ) + else: + ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), + scripts=""" + if(!ctx._source.kb_id.contains('%s')) + ctx._source.kb_id.add('%s'); + """ % (doc.kb_id, doc.kb_id), + idxnm=search.index_name( + kb.tenant_id) + ) + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) + + +@manager.route('/rm', methods=['POST']) +@login_required +@validate_request("doc_id") +def rm(): + req = request.json + try: + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(retmsg="Document not found!") + if not DocumentService.delete_by_id(req["doc_id"]): + return get_data_error_result( + retmsg="Database error (Document removal)!") + e, kb = KnowledgebaseService.get_by_id(doc.kb_id) + MINIO.rm(kb.id, doc.location) + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) + + +@manager.route('/rename', methods=['POST']) +@login_required +@validate_request("doc_id", "name", "old_name") +def rename(): + req = request.json + if pathlib.Path(req["name"].lower()).suffix != 
pathlib.Path( + req["old_name"].lower()).suffix: + get_json_result( + data=False, + retmsg="The extension of file can't be changed", + retcode=RetCode.ARGUMENT_ERROR) + + try: + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(retmsg="Document not found!") + if DocumentService.query(name=req["name"], kb_id=doc.kb_id): + return get_data_error_result( + retmsg="Duplicated document name in the same knowledgebase.") + + if not DocumentService.update_by_id( + req["doc_id"], {"name": req["name"]}): + return get_data_error_result( + retmsg="Database error (Document rename)!") + + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) diff --git a/web_server/apps/kb_app.py b/web_server/apps/kb_app.py new file mode 100644 index 0000000000000000000000000000000000000000..c035cb6375e2e0a03d949f33743e768ca1011523 --- /dev/null +++ b/web_server/apps/kb_app.py @@ -0,0 +1,102 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from flask import request +from flask_login import login_required, current_user + +from web_server.db.services import duplicate_name +from web_server.db.services.user_service import TenantService, UserTenantService +from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request +from web_server.utils import get_uuid, get_format_time +from web_server.db import StatusEnum, UserTenantRole +from web_server.db.services.kb_service import KnowledgebaseService +from web_server.db.db_models import Knowledgebase +from web_server.settings import stat_logger, RetCode +from web_server.utils.api_utils import get_json_result + + +@manager.route('/create', methods=['post']) +@login_required +@validate_request("name", "description", "permission", "embd_id", "parser_id") +def create(): + req = request.json + req["name"] = req["name"].strip() + req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value) + try: + req["id"] = get_uuid() + req["tenant_id"] = current_user.id + req["created_by"] = current_user.id + if not KnowledgebaseService.save(**req): return get_data_error_result() + return get_json_result(data={"kb_id": req["id"]}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/update', methods=['post']) +@login_required +@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id") +def update(): + req = request.json + req["name"] = req["name"].strip() + try: + if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]): + return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) + + e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) + if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!") + + if req["name"].lower() != kb.name.lower() \ + and 
len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1: + return get_data_error_result(retmsg="Duplicated knowledgebase name.") + + del req["kb_id"] + if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result() + + e, kb = KnowledgebaseService.get_by_id(kb.id) + if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!") + + return get_json_result(data=kb.to_json()) + except Exception as e: + return server_error_response(e) + + +@manager.route('/list', methods=['GET']) +@login_required +def list(): + page_number = request.args.get("page", 1) + items_per_page = request.args.get("page_size", 15) + orderby = request.args.get("orderby", "create_time") + desc = request.args.get("desc", True) + try: + tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) + kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc) + return get_json_result(data=kbs) + except Exception as e: + return server_error_response(e) + + +@manager.route('/rm', methods=['post']) +@login_required +@validate_request("kb_id") +def rm(): + req = request.json + try: + if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]): + return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) + + if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.IN_VALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!") + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) \ No newline at end of file diff --git a/web_server/apps/user_app.py b/web_server/apps/user_app.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2ba43cec4058d2ff257421947a652777be067f --- /dev/null +++ b/web_server/apps/user_app.py @@ -0,0 +1,226 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
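The user endpoints that follow support two login paths: a password login whose payload is passed through decrypt() before being checked, and a GitHub OAuth flow in which /github_callback exchanges the ?code= parameter for an access token, stores it in the session, and 307-redirects back to /login, which then fetches the profile and primary e-mail from the GitHub API and auto-registers the account on first login. A rough sketch of the token-exchange step, using GitHub's standard OAuth web-flow endpoint (the credentials below are placeholders):

import requests


def exchange_github_code(client_id: str, client_secret: str, code: str) -> str:
    # Swap the OAuth "code" for an access token, mirroring what github_callback() does.
    res = requests.post(
        "https://github.com/login/oauth/access_token",
        data={"client_id": client_id, "client_secret": client_secret, "code": code},
        headers={"Accept": "application/json"},
        timeout=10,
    ).json()
    if "error" in res:
        raise RuntimeError(res.get("error_description", "OAuth exchange failed"))
    if "user:email" not in res.get("scope", "").split(","):
        raise RuntimeError("the user:email scope was not granted")
    return res["access_token"]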
+# +from flask import request, session, redirect, url_for +from werkzeug.security import generate_password_hash, check_password_hash +from flask_login import login_required, current_user, login_user, logout_user +from web_server.utils.api_utils import server_error_response, validate_request +from web_server.utils import get_uuid, get_format_time, decrypt, download_img +from web_server.db import UserTenantRole +from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS +from web_server.db.services.user_service import UserService, TenantService, UserTenantService +from web_server.settings import stat_logger +from web_server.utils.api_utils import get_json_result, cors_reponse + + +@manager.route('/login', methods=['POST', 'GET']) +def login(): + userinfo = None + login_channel = "password" + if session.get("access_token"): + login_channel = session["access_token_from"] + if session["access_token_from"] == "github": + userinfo = user_info_from_github(session["access_token"]) + elif not request.json: + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, + retmsg='Unautherized!') + + email = request.json.get('email') if not userinfo else userinfo["email"] + users = UserService.query(email=email) + if not users: + if request.json is not None: + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!') + avatar = "" + try: + avatar = download_img(userinfo["avatar_url"]) + except Exception as e: + stat_logger.exception(e) + try: + users = user_register({ + "access_token": session["access_token"], + "email": userinfo["email"], + "avatar": avatar, + "nickname": userinfo["login"], + "login_channel": login_channel, + "last_login_time": get_format_time(), + "is_superuser": False, + }) + if not users: raise Exception('Register user failure.') + if len(users) > 1: raise Exception('Same E-mail exist!') + user = users[0] + login_user(user) + return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") + except Exception as e: + stat_logger.exception(e) + return server_error_response(e) + elif not request.json: + login_user(users[0]) + return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!") + + password = request.json.get('password') + try: + password = decrypt(password) + except: + return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password') + + user = UserService.query_user(email, password) + if user: + response_data = user.to_json() + user.access_token = get_uuid() + login_user(user) + user.save() + msg = "Welcome back!" 
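+        # Password-login branch: a fresh access_token is minted and persisted on every
+        # successful login; note that response_data was captured above, before the token
+        # rotation, so the JSON body still carries the previous token value.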
+ return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg) + else: + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!') + + +@manager.route('/github_callback', methods=['GET']) +def github_callback(): + try: + import requests + res = requests.post(GITHUB_OAUTH.get("url"), data={ + "client_id": GITHUB_OAUTH.get("client_id"), + "client_secret": GITHUB_OAUTH.get("secret_key"), + "code": request.args.get('code') + },headers={"Accept": "application/json"}) + res = res.json() + if "error" in res: + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, + retmsg=res["error_description"]) + + if "user:email" not in res["scope"].split(","): + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope') + + session["access_token"] = res["access_token"] + session["access_token_from"] = "github" + return redirect(url_for("user.login"), code=307) + + except Exception as e: + stat_logger.exception(e) + return server_error_response(e) + + +def user_info_from_github(access_token): + import requests + headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"} + res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) + user_info = res.json() + email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json() + user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"] + return user_info + + +@manager.route("/logout", methods=['GET']) +@login_required +def log_out(): + current_user.access_token = "" + current_user.save() + logout_user() + return get_json_result(data=True) + + +@manager.route("/setting", methods=["POST"]) +@login_required +def setting_user(): + update_dict = {} + request_data = request.json + if request_data.get("password"): + new_password = request_data.get("new_password") + if not check_password_hash(current_user.password, decrypt(request_data["password"])): + return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!') + + if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password)) + + for k in request_data.keys(): + if k in ["password", "new_password"]:continue + update_dict[k] = request_data[k] + + try: + UserService.update_by_id(current_user.id, update_dict) + return get_json_result(data=True) + except Exception as e: + stat_logger.exception(e) + return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR) + + +@manager.route("/info", methods=["GET"]) +@login_required +def user_info(): + return get_json_result(data=current_user.to_dict()) + + +def user_register(user): + user_id = get_uuid() + user["id"] = user_id + tenant = { + "id": user_id, + "name": user["nickname"] + "‘s Kingdom", + "llm_id": CHAT_MDL, + "embd_id": EMBEDDING_MDL, + "asr_id": ASR_MDL, + "parser_ids": PARSERS, + "img2txt_id": IMAGE2TEXT_MDL + } + usr_tenant = { + "tenant_id": user_id, + "user_id": user_id, + "invited_by": user_id, + "role": UserTenantRole.OWNER + } + + if not UserService.save(**user):return + TenantService.save(**tenant) + UserTenantService.save(**usr_tenant) + return UserService.query(email=user["email"]) + + +@manager.route("/register", methods=["POST"]) +@validate_request("nickname", "email", "password") +def user_add(): + req = request.json + if UserService.query(email=req["email"]): + return 
get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR) + + user_dict = { + "access_token": get_uuid(), + "email": req["email"], + "nickname": req["nickname"], + "password": decrypt(req["password"]), + "login_channel": "password", + "last_login_time": get_format_time(), + "is_superuser": False, + } + try: + users = user_register(user_dict) + if not users: raise Exception('Register user failure.') + if len(users) > 1: raise Exception('Same E-mail exist!') + user = users[0] + login_user(user) + return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") + except Exception as e: + stat_logger.exception(e) + return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) + + + +@manager.route("/tenant_info", methods=["GET"]) +@login_required +def tenant_info(): + try: + tenants = TenantService.get_by_user_id(current_user.id) + return get_json_result(data=tenants) + except Exception as e: + return server_error_response(e) diff --git a/web_server/db/__init__.py b/web_server/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9984299c1c710dd9d316d6a0c571ce852ad3985c --- /dev/null +++ b/web_server/db/__init__.py @@ -0,0 +1,54 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from enum import Enum +from enum import IntEnum +from strenum import StrEnum + + +class StatusEnum(Enum): + VALID = "1" + IN_VALID = "0" + + +class UserTenantRole(StrEnum): + OWNER = 'owner' + ADMIN = 'admin' + NORMAL = 'normal' + + +class TenantPermission(StrEnum): + ME = 'me' + TEAM = 'team' + + +class SerializedType(IntEnum): + PICKLE = 1 + JSON = 2 + + +class FileType(StrEnum): + PDF = 'pdf' + DOC = 'doc' + VISUAL = 'visual' + AURAL = 'aural' + VIRTUAL = 'virtual' + + +class LLMType(StrEnum): + CHAT = 'chat' + EMBEDDING = 'embedding' + SPEECH2TEXT = 'speech2text' + IMAGE2TEXT = 'image2text' \ No newline at end of file diff --git a/web_server/db/db_models.py b/web_server/db/db_models.py new file mode 100644 index 0000000000000000000000000000000000000000..b6761680369220103a3a4c1ad33bfaa2a057c8d2 --- /dev/null +++ b/web_server/db/db_models.py @@ -0,0 +1,616 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
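The data layer added below is built on peewee over a pooled MySQL connection: every table inherits from a BaseModel that keeps paired *_time epoch columns and *_date DATETIME mirrors filled in automatically on insert and update, JSON and pickled payloads are stored in LONGTEXT columns through JSONField and SerializedField, and DatabaseLock wraps MySQL's GET_LOCK/RELEASE_LOCK as both a context manager and a decorator. A brief usage sketch of the generic query() helper and the lock; the Document model, the time bounds, and the lock name are hypothetical, and the snippet assumes this module has been imported against a reachable MySQL instance:

# Keyword arguments to BaseModel.query() become equality filters; a two-element
# list on a numeric or datetime field becomes a range filter, and a list on any
# other field becomes an IN filter. reverse/order_by control the sort order.
docs = Document.query(kb_id="kb42", status="1")
recent = Document.query(create_time=[t0, t1], reverse=True, order_by="create_time")

# DB.lock is bound to DatabaseLock below, so this serialises the critical section
# with a MySQL advisory lock (GET_LOCK / RELEASE_LOCK).
with DB.lock("init_database_tables", timeout=10):
    init_database_tables()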
+# +import inspect +import os +import sys +import typing +import operator +from functools import wraps +from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer +from flask_login import UserMixin + +from peewee import ( + BigAutoField, BigIntegerField, BooleanField, CharField, + CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField, + Field, Model, Metadata +) +from playhouse.pool import PooledMySQLDatabase + +from web_server.db import SerializedType +from web_server.settings import DATABASE, stat_logger, SECRET_KEY +from web_server.utils.log_utils import getLogger +from web_server import utils + +LOGGER = getLogger() + + +def singleton(cls, *args, **kw): + instances = {} + + def _singleton(): + key = str(cls) + str(os.getpid()) + if key not in instances: + instances[key] = cls(*args, **kw) + return instances[key] + + return _singleton + + +CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField} +AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"} + + +class LongTextField(TextField): + field_type = 'LONGTEXT' + + +class JSONField(LongTextField): + default_value = {} + + def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs): + self._object_hook = object_hook + self._object_pairs_hook = object_pairs_hook + super().__init__(**kwargs) + + def db_value(self, value): + if value is None: + value = self.default_value + return utils.json_dumps(value) + + def python_value(self, value): + if not value: + return self.default_value + return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) + + +class ListField(JSONField): + default_value = [] + + +class SerializedField(LongTextField): + def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs): + self._serialized_type = serialized_type + self._object_hook = object_hook + self._object_pairs_hook = object_pairs_hook + super().__init__(**kwargs) + + def db_value(self, value): + if self._serialized_type == SerializedType.PICKLE: + return utils.serialize_b64(value, to_str=True) + elif self._serialized_type == SerializedType.JSON: + if value is None: + return None + return utils.json_dumps(value, with_type=True) + else: + raise ValueError(f"the serialized type {self._serialized_type} is not supported") + + def python_value(self, value): + if self._serialized_type == SerializedType.PICKLE: + return utils.deserialize_b64(value) + elif self._serialized_type == SerializedType.JSON: + if value is None: + return {} + return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) + else: + raise ValueError(f"the serialized type {self._serialized_type} is not supported") + + +def is_continuous_field(cls: typing.Type) -> bool: + if cls in CONTINUOUS_FIELD_TYPE: + return True + for p in cls.__bases__: + if p in CONTINUOUS_FIELD_TYPE: + return True + elif p != Field and p != object: + if is_continuous_field(p): + return True + else: + return False + + +def auto_date_timestamp_field(): + return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX} + + +def auto_date_timestamp_db_field(): + return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX} + + +def remove_field_name_prefix(field_name): + return field_name[2:] if field_name.startswith('f_') else field_name + + +class BaseModel(Model): + create_time = BigIntegerField(null=True) + create_date = DateTimeField(null=True) + update_time = 
BigIntegerField(null=True) + update_date = DateTimeField(null=True) + + def to_json(self): + # This function is obsolete + return self.to_dict() + + def to_dict(self): + return self.__dict__['__data__'] + + def to_human_model_dict(self, only_primary_with: list = None): + model_dict = self.__dict__['__data__'] + + if not only_primary_with: + return {remove_field_name_prefix(k): v for k, v in model_dict.items()} + + human_model_dict = {} + for k in self._meta.primary_key.field_names: + human_model_dict[remove_field_name_prefix(k)] = model_dict[k] + for k in only_primary_with: + human_model_dict[k] = model_dict[f'f_{k}'] + return human_model_dict + + @property + def meta(self) -> Metadata: + return self._meta + + @classmethod + def get_primary_keys_name(cls): + return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [ + cls._meta.primary_key.name] + + @classmethod + def getter_by(cls, attr): + return operator.attrgetter(attr)(cls) + + @classmethod + def query(cls, reverse=None, order_by=None, **kwargs): + filters = [] + for f_n, f_v in kwargs.items(): + attr_name = '%s' % f_n + if not hasattr(cls, attr_name) or f_v is None: + continue + if type(f_v) in {list, set}: + f_v = list(f_v) + if is_continuous_field(type(getattr(cls, attr_name))): + if len(f_v) == 2: + for i, v in enumerate(f_v): + if isinstance(v, str) and f_n in auto_date_timestamp_field(): + # time type: %Y-%m-%d %H:%M:%S + f_v[i] = utils.date_string_to_timestamp(v) + lt_value = f_v[0] + gt_value = f_v[1] + if lt_value is not None and gt_value is not None: + filters.append(cls.getter_by(attr_name).between(lt_value, gt_value)) + elif lt_value is not None: + filters.append(operator.attrgetter(attr_name)(cls) >= lt_value) + elif gt_value is not None: + filters.append(operator.attrgetter(attr_name)(cls) <= gt_value) + else: + filters.append(operator.attrgetter(attr_name)(cls) << f_v) + else: + filters.append(operator.attrgetter(attr_name)(cls) == f_v) + if filters: + query_records = cls.select().where(*filters) + if reverse is not None: + if not order_by or not hasattr(cls, f"{order_by}"): + order_by = "create_time" + if reverse is True: + query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc()) + elif reverse is False: + query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc()) + return [query_record for query_record in query_records] + else: + return [] + + @classmethod + def insert(cls, __data=None, **insert): + if isinstance(__data, dict) and __data: + __data[cls._meta.combined["create_time"]] = utils.current_timestamp() + if insert: + insert["create_time"] = utils.current_timestamp() + + return super().insert(__data, **insert) + + # update and insert will call this method + @classmethod + def _normalize_data(cls, data, kwargs): + normalized = super()._normalize_data(data, kwargs) + if not normalized: + return {} + + normalized[cls._meta.combined["update_time"]] = utils.current_timestamp() + + for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX: + if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \ + cls._meta.combined[f"{f_n}_time"] in normalized and \ + normalized[cls._meta.combined[f"{f_n}_time"]] is not None: + normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date( + normalized[cls._meta.combined[f"{f_n}_time"]]) + + return normalized + + +class JsonSerializedField(SerializedField): + def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs): + super(JsonSerializedField, 
self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook, + object_pairs_hook=object_pairs_hook, **kwargs) + + +@singleton +class BaseDataBase: + def __init__(self): + database_config = DATABASE.copy() + db_name = database_config.pop("name") + self.database_connection = PooledMySQLDatabase(db_name, **database_config) + stat_logger.info('init mysql database on cluster mode successfully') + + +class DatabaseLock: + def __init__(self, lock_name, timeout=10, db=None): + self.lock_name = lock_name + self.timeout = int(timeout) + self.db = db if db else DB + + def lock(self): + # SQL parameters only support %s format placeholders + cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout)) + ret = cursor.fetchone() + if ret[0] == 0: + raise Exception(f'acquire mysql lock {self.lock_name} timeout') + elif ret[0] == 1: + return True + else: + raise Exception(f'failed to acquire lock {self.lock_name}') + + def unlock(self): + cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,)) + ret = cursor.fetchone() + if ret[0] == 0: + raise Exception(f'mysql lock {self.lock_name} was not established by this thread') + elif ret[0] == 1: + return True + else: + raise Exception(f'mysql lock {self.lock_name} does not exist') + + def __enter__(self): + if isinstance(self.db, PooledMySQLDatabase): + self.lock() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if isinstance(self.db, PooledMySQLDatabase): + self.unlock() + + def __call__(self, func): + @wraps(func) + def magic(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return magic + + +DB = BaseDataBase().database_connection +DB.lock = DatabaseLock + + +def close_connection(): + try: + if DB: + DB.close() + except Exception as e: + LOGGER.exception(e) + + +class DataBaseModel(BaseModel): + class Meta: + database = DB + + +@DB.connection_context() +def init_database_tables(): + members = inspect.getmembers(sys.modules[__name__], inspect.isclass) + table_objs = [] + create_failed_list = [] + for name, obj in members: + if obj != DataBaseModel and issubclass(obj, DataBaseModel): + table_objs.append(obj) + LOGGER.info(f"start create table {obj.__name__}") + try: + obj.create_table() + LOGGER.info(f"create table success: {obj.__name__}") + except Exception as e: + LOGGER.exception(e) + create_failed_list.append(obj.__name__) + if create_failed_list: + LOGGER.info(f"create tables failed: {create_failed_list}") + raise Exception(f"create tables failed: {create_failed_list}") + + +def fill_db_model_object(model_object, human_model_dict): + for k, v in human_model_dict.items(): + attr_name = '%s' % k + if hasattr(model_object.__class__, attr_name): + setattr(model_object, attr_name, v) + return model_object + + +class User(DataBaseModel, UserMixin): + id = CharField(max_length=32, primary_key=True) + access_token = CharField(max_length=255, null=True) + nickname = CharField(max_length=100, null=False, help_text="nicky name") + password = CharField(max_length=255, null=True, help_text="password") + email = CharField(max_length=255, null=False, help_text="email", index=True) + avatar = TextField(null=True, help_text="avatar base64 string") + language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese") + color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark") + last_login_time = DateTimeField(null=True) + is_authenticated = CharField(max_length=1, null=False, default="1") + is_active = 
CharField(max_length=1, null=False, default="1") + is_anonymous = CharField(max_length=1, null=False, default="0") + login_channel = CharField(null=True, help_text="from which user login") + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + is_superuser = BooleanField(null=True, help_text="is root", default=False) + + def __str__(self): + return self.email + + def get_id(self): + jwt = Serializer(secret_key=SECRET_KEY) + return jwt.dumps(str(self.access_token)) + + class Meta: + db_table = "user" + + +class Tenant(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + name = CharField(max_length=100, null=True, help_text="Tenant name") + public_key = CharField(max_length=255, null=True) + llm_id = CharField(max_length=128, null=False, help_text="default llm ID") + embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID") + asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID") + img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID") + parser_ids = CharField(max_length=128, null=False, help_text="default image to text model ID") + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + + class Meta: + db_table = "tenant" + + +class UserTenant(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + user_id = CharField(max_length=32, null=False) + tenant_id = CharField(max_length=32, null=False) + role = CharField(max_length=32, null=False, help_text="UserTenantRole") + invited_by = CharField(max_length=32, null=False) + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + + class Meta: + db_table = "user_tenant" + + +class InvitationCode(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + code = CharField(max_length=32, null=False) + visit_time = DateTimeField(null=True) + user_id = CharField(max_length=32, null=True) + tenant_id = CharField(max_length=32, null=True) + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + + class Meta: + db_table = "invitation_code" + + +class LLMFactories(DataBaseModel): + name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True) + logo = TextField(null=True, help_text="llm logo base64") + tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + + def __str__(self): + return self.name + + class Meta: + db_table = "llm_factories" + + +class LLM(DataBaseModel): + # defautlt LLMs for every users + llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) + fid = CharField(max_length=128, null=False, help_text="LLM factory id") + tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + + def __str__(self): + return self.llm_name + + class Meta: + db_table = "llm" + + +class TenantLLM(DataBaseModel): + tenant_id = CharField(max_length=32, null=False) + llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") + model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") + 
llm_name = CharField(max_length=128, null=False, help_text="LLM name") + api_key = CharField(max_length=255, null=True, help_text="API KEY") + api_base = CharField(max_length=255, null=True, help_text="API Base") + + def __str__(self): + return self.llm_name + + class Meta: + db_table = "tenant_llm" + primary_key = CompositeKey('tenant_id', 'llm_factory') + + +class Knowledgebase(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + avatar = TextField(null=True, help_text="avatar base64 string") + tenant_id = CharField(max_length=32, null=False) + name = CharField(max_length=128, null=False, help_text="KB name", index=True) + description = TextField(null=True, help_text="KB description") + permission = CharField(max_length=16, null=False, help_text="me|team") + created_by = CharField(max_length=32, null=False) + doc_num = IntegerField(default=0) + embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID") + parser_id = CharField(max_length=32, null=False, help_text="default parser ID") + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + + def __str__(self): + return self.name + + class Meta: + db_table = "knowledgebase" + + +class Document(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + thumbnail = TextField(null=True, help_text="thumbnail base64 string") + kb_id = CharField(max_length=256, null=False, index=True) + parser_id = CharField(max_length=32, null=False, help_text="default parser ID") + source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from") + type = CharField(max_length=32, null=False, help_text="file extension") + created_by = CharField(max_length=32, null=False, help_text="who created it") + name = CharField(max_length=255, null=True, help_text="file name", index=True) + location = CharField(max_length=255, null=True, help_text="where dose it store") + size = IntegerField(default=0) + token_num = IntegerField(default=0) + chunk_num = IntegerField(default=0) + progress = FloatField(default=0) + progress_msg = CharField(max_length=255, null=True, help_text="process message", default="") + process_begin_at = DateTimeField(null=True) + process_duation = FloatField(default=0) + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + + class Meta: + db_table = "document" + + +class Dialog(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + tenant_id = CharField(max_length=32, null=False) + name = CharField(max_length=255, null=True, help_text="dialog application name") + description = TextField(null=True, help_text="Dialog description") + icon = CharField(max_length=16, null=False, help_text="dialog icon") + language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") + llm_id = CharField(max_length=32, null=False, help_text="default llm ID") + llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom", + default="Creative") + llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, + "presence_penalty": 0.4, "max_tokens": 215}) + prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") + prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?", + "parameters": [], "empty_response": "Sorry! 
知识库中未找到相关内容!"}) + status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") + + class Meta: + db_table = "dialog" + + +class DialogKb(DataBaseModel): + dialog_id = CharField(max_length=32, null=False, index=True) + kb_id = CharField(max_length=32, null=False) + + class Meta: + db_table = "dialog_kb" + primary_key = CompositeKey('dialog_id', 'kb_id') + + +class Conversation(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + dialog_id = CharField(max_length=32, null=False, index=True) + name = CharField(max_length=255, null=True, help_text="converastion name") + message = JSONField(null=True) + + class Meta: + db_table = "conversation" + + +""" +class Job(DataBaseModel): + # multi-party common configuration + f_user_id = CharField(max_length=25, null=True) + f_job_id = CharField(max_length=25, index=True) + f_name = CharField(max_length=500, null=True, default='') + f_description = TextField(null=True, default='') + f_tag = CharField(max_length=50, null=True, default='') + f_dsl = JSONField() + f_runtime_conf = JSONField() + f_runtime_conf_on_party = JSONField() + f_train_runtime_conf = JSONField(null=True) + f_roles = JSONField() + f_initiator_role = CharField(max_length=50) + f_initiator_party_id = CharField(max_length=50) + f_status = CharField(max_length=50) + f_status_code = IntegerField(null=True) + f_user = JSONField() + # this party configuration + f_role = CharField(max_length=50, index=True) + f_party_id = CharField(max_length=10, index=True) + f_is_initiator = BooleanField(null=True, default=False) + f_progress = IntegerField(null=True, default=0) + f_ready_signal = BooleanField(default=False) + f_ready_time = BigIntegerField(null=True) + f_cancel_signal = BooleanField(default=False) + f_cancel_time = BigIntegerField(null=True) + f_rerun_signal = BooleanField(default=False) + f_end_scheduling_updates = IntegerField(null=True, default=0) + + f_engine_name = CharField(max_length=50, null=True) + f_engine_type = CharField(max_length=10, null=True) + f_cores = IntegerField(default=0) + f_memory = IntegerField(default=0) # MB + f_remaining_cores = IntegerField(default=0) + f_remaining_memory = IntegerField(default=0) # MB + f_resource_in_use = BooleanField(default=False) + f_apply_resource_time = BigIntegerField(null=True) + f_return_resource_time = BigIntegerField(null=True) + + f_inheritance_info = JSONField(null=True) + f_inheritance_status = CharField(max_length=50, null=True) + + f_start_time = BigIntegerField(null=True) + f_start_date = DateTimeField(null=True) + f_end_time = BigIntegerField(null=True) + f_end_date = DateTimeField(null=True) + f_elapsed = BigIntegerField(null=True) + + class Meta: + db_table = "t_job" + primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id') + + + +class PipelineComponentMeta(DataBaseModel): + f_model_id = CharField(max_length=100, index=True) + f_model_version = CharField(max_length=100, index=True) + f_role = CharField(max_length=50, index=True) + f_party_id = CharField(max_length=10, index=True) + f_component_name = CharField(max_length=100, index=True) + f_component_module_name = CharField(max_length=100) + f_model_alias = CharField(max_length=100, index=True) + f_model_proto_index = JSONField(null=True) + f_run_parameters = JSONField(null=True) + f_archive_sha256 = CharField(max_length=100, null=True) + f_archive_from_ip = CharField(max_length=100, null=True) + + class Meta: + db_table = 't_pipeline_component_meta' + indexes = ( + (('f_model_id', 'f_model_version', 
'f_role', 'f_party_id', 'f_component_name'), True), + ) + + +""" diff --git a/web_server/db/db_services.py b/web_server/db/db_services.py new file mode 100644 index 0000000000000000000000000000000000000000..9f8a0a02a0e17ea4961f3ecec1927b15887f0a03 --- /dev/null +++ b/web_server/db/db_services.py @@ -0,0 +1,157 @@ +# +# Copyright 2021 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +import json +import time +from functools import wraps +from shortuuid import ShortUUID + +from web_server.versions import get_fate_version + +from web_server.errors.error_services import * +from web_server.settings import ( + GRPC_PORT, HOST, HTTP_PORT, + RANDOM_INSTANCE_ID, stat_logger, +) + + +instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}' +server_instance = ( + f'{HOST}:{GRPC_PORT}', + json.dumps({ + 'instance_id': instance_id, + 'timestamp': round(time.time() * 1000), + 'version': get_fate_version() or '', + 'host': HOST, + 'grpc_port': GRPC_PORT, + 'http_port': HTTP_PORT, + }), +) + + +def check_service_supported(method): + """Decorator to check if `service_name` is supported. + The attribute `supported_services` MUST be defined in class. + The first and second arguments of `method` MUST be `self` and `service_name`. + + :param Callable method: The class method. + :return: The inner wrapper function. + :rtype: Callable + """ + @wraps(method) + def magic(self, service_name, *args, **kwargs): + if service_name not in self.supported_services: + raise ServiceNotSupported(service_name=service_name) + return method(self, service_name, *args, **kwargs) + return magic + + +class ServicesDB(abc.ABC): + """Database for storage service urls. + Abstract base class for the real backends. + + """ + @property + @abc.abstractmethod + def supported_services(self): + """The names of supported services. + The returned list SHOULD contain `fateflow` (model download) and `servings` (FATE-Serving). + + :return: The service names. + :rtype: list + """ + pass + + @abc.abstractmethod + def _get_serving(self): + pass + + def get_serving(self): + + try: + return self._get_serving() + except ServicesError as e: + stat_logger.exception(e) + return [] + + @abc.abstractmethod + def _insert(self, service_name, service_url, value=''): + pass + + @check_service_supported + def insert(self, service_name, service_url, value=''): + """Insert a service url to database. + + :param str service_name: The service name. + :param str service_url: The service url. + :return: None + """ + try: + self._insert(service_name, service_url, value) + except ServicesError as e: + stat_logger.exception(e) + + @abc.abstractmethod + def _delete(self, service_name, service_url): + pass + + @check_service_supported + def delete(self, service_name, service_url): + """Delete a service url from database. + + :param str service_name: The service name. + :param str service_url: The service url. 
+ :return: None + """ + try: + self._delete(service_name, service_url) + except ServicesError as e: + stat_logger.exception(e) + + def register_flow(self): + """Call `self.insert` for insert the flow server address to databae. + + :return: None + """ + self.insert('flow-server', *server_instance) + + def unregister_flow(self): + """Call `self.delete` for delete the flow server address from databae. + + :return: None + """ + self.delete('flow-server', server_instance[0]) + + @abc.abstractmethod + def _get_urls(self, service_name, with_values=False): + pass + + @check_service_supported + def get_urls(self, service_name, with_values=False): + """Query service urls from database. The urls may belong to other nodes. + Currently, only `fateflow` (model download) urls and `servings` (FATE-Serving) urls are supported. + `fateflow` is a url containing scheme, host, port and path, + while `servings` only contains host and port. + + :param str service_name: The service name. + :return: The service urls. + :rtype: list + """ + try: + return self._get_urls(service_name, with_values) + except ServicesError as e: + stat_logger.exception(e) + return [] diff --git a/web_server/db/db_utils.py b/web_server/db/db_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..df00ecb480114d0ebb83f484a3e5f9203f8a49fe --- /dev/null +++ b/web_server/db/db_utils.py @@ -0,0 +1,131 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
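ServicesDB defined just above is an abstract registry backend: a concrete subclass only supplies the underscore-prefixed primitives plus supported_services, while the public insert/delete/get_urls wrappers add the supported-service check and error logging. A rough in-memory sketch of such a backend, purely illustrative and not a class shipped in this diff:

from web_server.db.db_services import ServicesDB

class InMemoryServicesDB(ServicesDB):
    """Toy backend that keeps service urls in a dict, e.g. for local tests."""

    def __init__(self):
        self._urls = {}  # service_name -> {url: value}

    @property
    def supported_services(self):
        # must include every name later passed to insert()/delete()/get_urls()
        return ['fateflow', 'servings', 'flow-server']

    def _get_serving(self):
        return list(self._urls.get('servings', {}))

    def _insert(self, service_name, service_url, value=''):
        self._urls.setdefault(service_name, {})[service_url] = value

    def _delete(self, service_name, service_url):
        self._urls.get(service_name, {}).pop(service_url, None)

    def _get_urls(self, service_name, with_values=False):
        urls = self._urls.get(service_name, {})
        return list(urls.items()) if with_values else list(urls)

db = InMemoryServicesDB()
db.insert('servings', '127.0.0.1:8059')
assert db.get_urls('servings') == ['127.0.0.1:8059']
db.delete('servings', '127.0.0.1:8059')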
+# +import operator +from functools import reduce +from typing import Dict, Type, Union + +from web_server.utils import current_timestamp, timestamp_to_date + +from web_server.db.db_models import DB, DataBaseModel +from web_server.db.runtime_config import RuntimeConfig +from web_server.utils.log_utils import getLogger +from enum import Enum + + +LOGGER = getLogger() + + +@DB.connection_context() +def bulk_insert_into_db(model, data_source, replace_on_conflict=False): + DB.create_tables([model]) + + current_time = current_timestamp() + current_date = timestamp_to_date(current_time) + + for data in data_source: + if 'f_create_time' not in data: + data['f_create_time'] = current_time + data['f_create_date'] = timestamp_to_date(data['f_create_time']) + data['f_update_time'] = current_time + data['f_update_date'] = current_date + + preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'}) + + batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000 + + for i in range(0, len(data_source), batch_size): + with DB.atomic(): + query = model.insert_many(data_source[i:i + batch_size]) + if replace_on_conflict: + query = query.on_conflict(preserve=preserve) + query.execute() + + +def get_dynamic_db_model(base, job_id): + return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id))) + + +def get_dynamic_tracking_table_index(job_id): + return job_id[:8] + + +def fill_db_model_object(model_object, human_model_dict): + for k, v in human_model_dict.items(): + attr_name = 'f_%s' % k + if hasattr(model_object.__class__, attr_name): + setattr(model_object, attr_name, v) + return model_object + + +# https://docs.peewee-orm.com/en/latest/peewee/query_operators.html +supported_operators = { + '==': operator.eq, + '<': operator.lt, + '<=': operator.le, + '>': operator.gt, + '>=': operator.ge, + '!=': operator.ne, + '<<': operator.lshift, + '>>': operator.rshift, + '%': operator.mod, + '**': operator.pow, + '^': operator.xor, + '~': operator.inv, +} + +def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]): + expression = [] + + for field, value in query.items(): + if not isinstance(value, (list, tuple)): + value = ('==', value) + op, *val = value + + field = getattr(model, f'f_{field}') + value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val) + expression.append(value) + + return reduce(operator.iand, expression) + + +def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0, + query: dict = None, order_by: Union[str, list, tuple] = None): + data = model.select() + if query: + data = data.where(query_dict2expression(model, query)) + count = data.count() + + if not order_by: + order_by = 'create_time' + if not isinstance(order_by, (list, tuple)): + order_by = (order_by, 'asc') + order_by, order = order_by + order_by = getattr(model, f'f_{order_by}') + order_by = getattr(order_by, order)() + data = data.order_by(order_by) + + if limit > 0: + data = data.limit(limit) + if offset > 0: + data = data.offset(offset) + + return list(data), count + + +class StatusEnum(Enum): + # 样本可用状态 + VALID = "1" + IN_VALID = "0" \ No newline at end of file diff --git a/web_server/db/init_data.py b/web_server/db/init_data.py new file mode 100644 index 0000000000000000000000000000000000000000..882b62ef9559b19a3b96aba22cef321ff7a0e6ba --- /dev/null +++ b/web_server/db/init_data.py @@ -0,0 +1,141 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
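query_dict2expression and query_db above turn a plain dict into peewee filter expressions over the f_-prefixed columns, so call sites stay free of ORM details. A hedged sketch of a caller; it assumes a DataBaseModel subclass whose fields follow the f_ naming convention (such as the commented-out Job model in db_models.py), which none of the models added above currently use:

from web_server.db.db_utils import query_db

def list_running(model, page_size=20, page=0):
    # `model` is any DataBaseModel subclass with f_-prefixed columns.
    records, total = query_db(
        model,
        limit=page_size,
        offset=page * page_size,
        query={
            'status': 'running',           # becomes f_status == 'running'
            'progress': ('>=', 50),        # any operator key from supported_operators
        },
        order_by=('create_time', 'desc'),  # resolved to f_create_time
    )
    return records, total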
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import time +import uuid + +from web_server.db import LLMType +from web_server.db.db_models import init_database_tables as init_web_db +from web_server.db.services import UserService +from web_server.db.services.llm_service import LLMFactoriesService, LLMService + + +def init_superuser(): + user_info = { + "id": uuid.uuid1().hex, + "password": "admin", + "nickname": "admin", + "is_superuser": True, + "email": "kai.hu@infiniflow.org", + "creator": "system", + "status": "1", + } + UserService.save(**user_info) + + +def init_llm_factory(): + factory_infos = [{ + "name": "OpenAI", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "status": "1", + },{ + "name": "通义千问", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "status": "1", + },{ + "name": "智普AI", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "status": "1", + },{ + "name": "文心一言", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "status": "1", + }, + ] + llm_infos = [{ + "fid": factory_infos[0]["name"], + "llm_name": "gpt-3.5-turbo", + "tags": "LLM,CHAT,4K", + "model_type": LLMType.CHAT.value + },{ + "fid": factory_infos[0]["name"], + "llm_name": "gpt-3.5-turbo-16k-0613", + "tags": "LLM,CHAT,16k", + "model_type": LLMType.CHAT.value + },{ + "fid": factory_infos[0]["name"], + "llm_name": "text-embedding-ada-002", + "tags": "TEXT EMBEDDING,8K", + "model_type": LLMType.EMBEDDING.value + },{ + "fid": factory_infos[0]["name"], + "llm_name": "whisper-1", + "tags": "SPEECH2TEXT", + "model_type": LLMType.SPEECH2TEXT.value + },{ + "fid": factory_infos[0]["name"], + "llm_name": "gpt-4", + "tags": "LLM,CHAT,8K", + "model_type": LLMType.CHAT.value + },{ + "fid": factory_infos[0]["name"], + "llm_name": "gpt-4-32k", + "tags": "LLM,CHAT,32K", + "model_type": LLMType.CHAT.value + },{ + "fid": factory_infos[0]["name"], + "llm_name": "gpt-4-vision-preview", + "tags": "LLM,CHAT,IMAGE2TEXT", + "model_type": LLMType.IMAGE2TEXT.value + },{ + "fid": factory_infos[1]["name"], + "llm_name": "qwen-turbo", + "tags": "LLM,CHAT,8K", + "model_type": LLMType.CHAT.value + },{ + "fid": factory_infos[1]["name"], + "llm_name": "qwen-plus", + "tags": "LLM,CHAT,32K", + "model_type": LLMType.CHAT.value + },{ + "fid": factory_infos[1]["name"], + "llm_name": "text-embedding-v2", + "tags": "TEXT EMBEDDING,2K", + "model_type": LLMType.EMBEDDING.value + },{ + "fid": factory_infos[1]["name"], + "llm_name": "paraformer-realtime-8k-v1", + "tags": "SPEECH2TEXT", + "model_type": LLMType.SPEECH2TEXT.value + },{ + "fid": factory_infos[1]["name"], + "llm_name": "qwen_vl_chat_v1", + "tags": "LLM,CHAT,IMAGE2TEXT", + "model_type": LLMType.IMAGE2TEXT.value + }, + ] + for info in factory_infos: + LLMFactoriesService.save(**info) + for info in llm_infos: + LLMService.save(**info) + + +def init_web_data(): + start_time = time.time() + if not UserService.get_all().count(): + init_superuser() + + if not LLMService.get_all().count():init_llm_factory() + + 
print("init web data success:{}".format(time.time() - start_time)) + + +if __name__ == '__main__': + init_web_db() + init_web_data() \ No newline at end of file diff --git a/web_server/db/operatioins.py b/web_server/db/operatioins.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e1596fa87285d41ce396fdb1e7b959532d4cff --- /dev/null +++ b/web_server/db/operatioins.py @@ -0,0 +1,21 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import operator +import time +import typing +from web_server.utils.log_utils import sql_logger +import peewee \ No newline at end of file diff --git a/web_server/db/reload_config_base.py b/web_server/db/reload_config_base.py new file mode 100644 index 0000000000000000000000000000000000000000..049acbf76d2eac11e8d44929d1a7f4053010b4a6 --- /dev/null +++ b/web_server/db/reload_config_base.py @@ -0,0 +1,27 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +class ReloadConfigBase: + @classmethod + def get_all(cls): + configs = {} + for k, v in cls.__dict__.items(): + if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"): + configs[k] = v + return configs + + @classmethod + def get(cls, config_name): + return getattr(cls, config_name) if hasattr(cls, config_name) else None \ No newline at end of file diff --git a/web_server/db/runtime_config.py b/web_server/db/runtime_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0955590897363719550ff49109f17aaa4516d4ec --- /dev/null +++ b/web_server/db/runtime_config.py @@ -0,0 +1,54 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from web_server.versions import get_versions +from .reload_config_base import ReloadConfigBase + + +class RuntimeConfig(ReloadConfigBase): + DEBUG = None + WORK_MODE = None + HTTP_PORT = None + JOB_SERVER_HOST = None + JOB_SERVER_VIP = None + ENV = dict() + SERVICE_DB = None + LOAD_CONFIG_MANAGER = False + + @classmethod + def init_config(cls, **kwargs): + for k, v in kwargs.items(): + if hasattr(cls, k): + setattr(cls, k, v) + + @classmethod + def init_env(cls): + cls.ENV.update(get_versions()) + + @classmethod + def load_config_manager(cls): + cls.LOAD_CONFIG_MANAGER = True + + @classmethod + def get_env(cls, key): + return cls.ENV.get(key, None) + + @classmethod + def get_all_env(cls): + return cls.ENV + + @classmethod + def set_service_db(cls, service_db): + cls.SERVICE_DB = service_db \ No newline at end of file diff --git a/web_server/db/service_registry.py b/web_server/db/service_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..dc704be931968139a5a7d22289accbe76797bff1 --- /dev/null +++ b/web_server/db/service_registry.py @@ -0,0 +1,164 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import socket +from pathlib import Path +from web_server import utils +from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo +from .reload_config_base import ReloadConfigBase + + +class ServiceRegistry(ReloadConfigBase): + @classmethod + @DB.connection_context() + def load_service(cls, **kwargs) -> [ServiceRegistryInfo]: + service_registry_list = ServiceRegistryInfo.query(**kwargs) + return [service for service in service_registry_list] + + @classmethod + @DB.connection_context() + def save_service_info(cls, server_name, service_name, uri, method="POST", server_info=None, params=None, data=None, headers=None, protocol="http"): + if not server_info: + server_list = ServerRegistry.query_server_info_from_db(server_name=server_name) + if not server_list: + raise Exception(f"no found server {server_name}") + server_info = server_list[0] + url = f"{server_info.f_protocol}://{server_info.f_host}:{server_info.f_port}{uri}" + else: + url = f"{server_info.get('protocol', protocol)}://{server_info.get('host')}:{server_info.get('port')}{uri}" + service_info = { + "f_server_name": server_name, + "f_service_name": service_name, + "f_url": url, + "f_method": method, + "f_params": params if params else {}, + "f_data": data if data else {}, + "f_headers": headers if headers else {} + } + entity_model, status = ServiceRegistryInfo.get_or_create( + f_server_name=server_name, + f_service_name=service_name, + defaults=service_info) + if status is False: + for key in service_info: + setattr(entity_model, key, service_info[key]) + entity_model.save(force_insert=False) + + +class ServerRegistry(ReloadConfigBase): + FATEBOARD = None + FATE_ON_STANDALONE = None + FATE_ON_EGGROLL = None + FATE_ON_SPARK = None + MODEL_STORE_ADDRESS = None + SERVINGS = None + FATEMANAGER = None + STUDIO = None + + @classmethod + def 
load(cls): + cls.load_server_info_from_conf() + cls.load_server_info_from_db() + + @classmethod + def load_server_info_from_conf(cls): + path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF + conf = utils.file_utils.load_yaml_conf(path) + if not isinstance(conf, dict): + raise ValueError('invalid config file') + + local_path = path.with_name(f'local.{utils.SERVICE_CONF}') + if local_path.exists(): + local_conf = utils.file_utils.load_yaml_conf(local_path) + if not isinstance(local_conf, dict): + raise ValueError('invalid local config file') + conf.update(local_conf) + for k, v in conf.items(): + if isinstance(v, dict): + setattr(cls, k.upper(), v) + + @classmethod + def register(cls, server_name, server_info): + cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http")) + setattr(cls, server_name, server_info) + + @classmethod + def save(cls, service_config): + update_server = {} + for server_name, server_info in service_config.items(): + cls.parameter_check(server_info) + api_info = server_info.pop("api", {}) + for service_name, info in api_info.items(): + ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info) + cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol="http") + setattr(cls, server_name.upper(), server_info) + return update_server + + @classmethod + def parameter_check(cls, service_info): + if "host" in service_info and "port" in service_info: + cls.connection_test(service_info.get("host"), service_info.get("port")) + + @classmethod + def connection_test(cls, ip, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = s.connect_ex((ip, port)) + if result != 0: + raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}") + + @classmethod + def query(cls, service_name, default=None): + service_info = getattr(cls, service_name, default) + if not service_info: + service_info = utils.get_base_config(service_name, default) + return service_info + + @classmethod + @DB.connection_context() + def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]: + if server_name: + server_list = ServerRegistryInfo.select().where(ServerRegistryInfo.f_server_name==server_name.upper()) + else: + server_list = ServerRegistryInfo.select() + return [server for server in server_list] + + @classmethod + @DB.connection_context() + def load_server_info_from_db(cls): + for server in cls.query_server_info_from_db(): + server_info = { + "host": server.f_host, + "port": server.f_port, + "protocol": server.f_protocol + } + setattr(cls, server.f_server_name.upper(), server_info) + + + @classmethod + @DB.connection_context() + def save_server_info_to_db(cls, server_name, host, port, protocol="http"): + server_info = { + "f_server_name": server_name, + "f_host": host, + "f_port": port, + "f_protocol": protocol + } + entity_model, status = ServerRegistryInfo.get_or_create( + f_server_name=server_name, + defaults=server_info) + if status is False: + for key in server_info: + setattr(entity_model, key, server_info[key]) + entity_model.save(force_insert=False) \ No newline at end of file diff --git a/web_server/db/services/__init__.py b/web_server/db/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c9314bcc205e8158f4b28f95f7fe0ff6a0b49e8 --- /dev/null +++ 
b/web_server/db/services/__init__.py @@ -0,0 +1,38 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pathlib +import re +from .user_service import UserService + + +def duplicate_name(query_func, **kwargs): + fnm = kwargs["name"] + objs = query_func(**kwargs) + if not objs: return fnm + ext = pathlib.Path(fnm).suffix #.jpg + nm = re.sub(r"%s$"%ext, "", fnm) + r = re.search(r"\([0-9]+\)$", nm) + c = 0 + if r: + c = int(r.group(1)) + nm = re.sub(r"\([0-9]+\)$", "", nm) + c += 1 + nm = f"{nm}({c})" + if ext: nm += f"{ext}" + + kwargs["name"] = nm + return duplicate_name(query_func, **kwargs) + diff --git a/web_server/db/services/common_service.py b/web_server/db/services/common_service.py new file mode 100644 index 0000000000000000000000000000000000000000..027f6f2825dc90a72056dfba8c94faf26d4991e0 --- /dev/null +++ b/web_server/db/services/common_service.py @@ -0,0 +1,153 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
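duplicate_name above keeps retrying a candidate name with an incremented "(n)" suffix (inserted before the file extension) until the supplied query function reports no clash; note the suffix-bumping regex has no capture group, so r.group(1) looks like it would raise once a name already carrying a "(n)" suffix collides again. A hedged usage sketch with an illustrative tenant id, using the KnowledgebaseService added further down in this diff:

from web_server.db.services import duplicate_name
from web_server.db.services.kb_service import KnowledgebaseService

new_name = duplicate_name(
    KnowledgebaseService.query,   # any query function accepting the same kwargs
    name="contracts.pdf",
    tenant_id="tenant-0001",      # illustrative
    status="1",
)
# -> "contracts.pdf" if unused, otherwise a "(n)"-suffixed variant such as "contracts(1).pdf"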
+# +from datetime import datetime + +import peewee + +from web_server.db.db_models import DB +from web_server.utils import datetime_format + + +class CommonService: + model = None + + @classmethod + @DB.connection_context() + def query(cls, cols=None, reverse=None, order_by=None, **kwargs): + return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs) + + @classmethod + @DB.connection_context() + def get_all(cls, cols=None, reverse=None, order_by=None): + if cols: + query_records = cls.model.select(*cols) + else: + query_records = cls.model.select() + if reverse is not None: + if not order_by or not hasattr(cls, order_by): + order_by = "create_time" + if reverse is True: + query_records = query_records.order_by(cls.model.getter_by(order_by).desc()) + elif reverse is False: + query_records = query_records.order_by(cls.model.getter_by(order_by).asc()) + return query_records + + @classmethod + @DB.connection_context() + def get(cls, **kwargs): + return cls.model.get(**kwargs) + + @classmethod + @DB.connection_context() + def get_or_none(cls, **kwargs): + try: + return cls.model.get(**kwargs) + except peewee.DoesNotExist: + return None + + @classmethod + @DB.connection_context() + def save(cls, **kwargs): + #if "id" not in kwargs: + # kwargs["id"] = get_uuid() + sample_obj = cls.model(**kwargs).save(force_insert=True) + return sample_obj + + @classmethod + @DB.connection_context() + def insert_many(cls, data_list, batch_size=100): + with DB.atomic(): + for i in range(0, len(data_list), batch_size): + cls.model.insert_many(data_list[i:i + batch_size]).execute() + + @classmethod + @DB.connection_context() + def update_many_by_id(cls, data_list): + cur = datetime_format(datetime.now()) + with DB.atomic(): + for data in data_list: + data["update_time"] = cur + cls.model.update(data).where(cls.model.id == data["id"]).execute() + + @classmethod + @DB.connection_context() + def update_by_id(cls, pid, data): + data["update_time"] = datetime_format(datetime.now()) + num = cls.model.update(data).where(cls.model.id == pid).execute() + return num + + @classmethod + @DB.connection_context() + def get_by_id(cls, pid): + try: + obj = cls.model.query(id=pid)[0] + return True, obj + except Exception as e: + return False, None + + @classmethod + @DB.connection_context() + def get_by_ids(cls, pids, cols=None): + if cols: + objs = cls.model.select(*cols) + else: + objs = cls.model.select() + return objs.where(cls.model.id.in_(pids)) + + @classmethod + @DB.connection_context() + def delete_by_id(cls, pid): + return cls.model.delete().where(cls.model.id == pid).execute() + + + @classmethod + @DB.connection_context() + def filter_delete(cls, filters): + with DB.atomic(): + num = cls.model.delete().where(*filters).execute() + return num + + @classmethod + @DB.connection_context() + def filter_update(cls, filters, update_data): + with DB.atomic(): + cls.model.update(update_data).where(*filters).execute() + + @staticmethod + def cut_list(tar_list, n): + length = len(tar_list) + arr = range(length) + result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]] + return result + + @classmethod + @DB.connection_context() + def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None): + in_filters_tuple_list = cls.cut_list(in_filters_list, 20) + if not filters: + filters = [] + res_list = [] + if cols: + for i in in_filters_tuple_list: + query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters) + if query_records: + res_list.extend([query_record for 
query_record in query_records]) + else: + for i in in_filters_tuple_list: + query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters) + if query_records: + res_list.extend([query_record for query_record in query_records]) + return res_list \ No newline at end of file diff --git a/web_server/db/services/dialog_service.py b/web_server/db/services/dialog_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e73217a7d856e1ce42c01afc84222c1ed0463e78 --- /dev/null +++ b/web_server/db/services/dialog_service.py @@ -0,0 +1,35 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import peewee +from werkzeug.security import generate_password_hash, check_password_hash + +from web_server.db.db_models import DB, UserTenant +from web_server.db.db_models import Dialog, Conversation, DialogKb +from web_server.db.services.common_service import CommonService +from web_server.utils import get_uuid, get_format_time +from web_server.db.db_utils import StatusEnum + + +class DialogService(CommonService): + model = Dialog + + +class ConversationService(CommonService): + model = Conversation + + +class DialogKbService(CommonService): + model = DialogKb diff --git a/web_server/db/services/document_service.py b/web_server/db/services/document_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e8746a4ac19f8f054dde515f90c894592251dac4 --- /dev/null +++ b/web_server/db/services/document_service.py @@ -0,0 +1,75 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
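CommonService above is the shared CRUD layer: a concrete service such as DialogService or ConversationService only binds `model` and inherits query/save/update/delete along with the @DB.connection_context() wiring. A short sketch of typical call sites; ids and field values are illustrative:

from web_server.db.services.dialog_service import ConversationService
from web_server.utils import get_uuid

conv_id = get_uuid()
ConversationService.save(
    id=conv_id,                # CommonService.save does not generate ids itself
    dialog_id="dialog-0001",   # illustrative
    name="first chat",
    message=[],                # JSONField column
)

ok, conv = ConversationService.get_by_id(conv_id)
if ok:
    ConversationService.update_by_id(conv_id, {"name": "renamed chat"})
    ConversationService.delete_by_id(conv_id)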
+# +from web_server.db import TenantPermission, FileType +from web_server.db.db_models import DB, Knowledgebase +from web_server.db.db_models import Document +from web_server.db.services.common_service import CommonService +from web_server.db.services.kb_service import KnowledgebaseService +from web_server.utils import get_uuid, get_format_time +from web_server.db.db_utils import StatusEnum + + +class DocumentService(CommonService): + model = Document + + @classmethod + @DB.connection_context() + def get_by_kb_id(cls, kb_id, page_number, items_per_page, + orderby, desc, keywords): + if keywords: + docs = cls.model.select().where( + cls.model.kb_id == kb_id, + cls.model.name.like(f"%%{keywords}%%")) + else: + docs = cls.model.select().where(cls.model.kb_id == kb_id) + if desc: + docs = docs.order_by(cls.model.getter_by(orderby).desc()) + else: + docs = docs.order_by(cls.model.getter_by(orderby).asc()) + + docs = docs.paginate(page_number, items_per_page) + + return list(docs.dicts()) + + @classmethod + @DB.connection_context() + def insert(cls, doc): + if not cls.save(**doc): + raise RuntimeError("Database error (Document)!") + e, doc = cls.get_by_id(doc["id"]) + if not e: + raise RuntimeError("Database error (Document retrieval)!") + e, kb = KnowledgebaseService.get_by_id(doc.kb_id) + if not KnowledgebaseService.update_by_id( + kb.id, {"doc_num": kb.doc_num + 1}): + raise RuntimeError("Database error (Knowledgebase)!") + return doc + + @classmethod + @DB.connection_context() + def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): + fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id] + docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where( + cls.model.status == StatusEnum.VALID.value, + cls.model.type != FileType.VIRTUAL, + cls.model.progress == 0, + cls.model.update_time >= tm, + cls.model.create_time % + comm == mod).order_by( + cls.model.update_time.asc()).paginate( + 1, + items_per_page) + return list(docs.dicts()) diff --git a/web_server/db/services/kb_service.py b/web_server/db/services/kb_service.py new file mode 100644 index 0000000000000000000000000000000000000000..84b2e4f93d9cb09bb7d19a97779bd6db089a5625 --- /dev/null +++ b/web_server/db/services/kb_service.py @@ -0,0 +1,43 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
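DocumentService.insert above is the write path that keeps Knowledgebase.doc_num in step with the documents table, and get_by_kb_id is the paged listing behind the document views. A hedged sketch, assuming the knowledgebase row already exists; ids, parser name, and object key are illustrative:

from web_server.db.services.document_service import DocumentService
from web_server.utils import get_uuid

doc = DocumentService.insert({
    "id": get_uuid(),
    "kb_id": "kb-0001",          # must reference an existing knowledgebase
    "parser_id": "general",      # illustrative parser id
    "created_by": "user-0001",
    "type": "pdf",
    "name": "contract.pdf",
    "location": "contract.pdf",  # object key in the file store
    "size": 102400,
})

# First page of that knowledgebase's documents, newest first, filtered by keyword.
docs = DocumentService.get_by_kb_id(
    "kb-0001", page_number=1, items_per_page=15,
    orderby="create_time", desc=True, keywords="contract")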
+# +import peewee +from werkzeug.security import generate_password_hash, check_password_hash + +from web_server.db import TenantPermission +from web_server.db.db_models import DB, UserTenant +from web_server.db.db_models import Knowledgebase +from web_server.db.services.common_service import CommonService +from web_server.utils import get_uuid, get_format_time +from web_server.db.db_utils import StatusEnum + + +class KnowledgebaseService(CommonService): + model = Knowledgebase + + @classmethod + @DB.connection_context() + def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc): + kbs = cls.model.select().where( + ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) + & (cls.model.status==StatusEnum.VALID.value) + ) + if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) + else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) + + kbs = kbs.paginate(page_number, items_per_page) + + return list(kbs.dicts()) + diff --git a/web_server/db/services/knowledgebase_service.py b/web_server/db/services/knowledgebase_service.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e8b34fe6ddda467f8d82df8adb808801b24a60 --- /dev/null +++ b/web_server/db/services/knowledgebase_service.py @@ -0,0 +1,31 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import peewee +from werkzeug.security import generate_password_hash, check_password_hash + +from web_server.db.db_models import DB, UserTenant +from web_server.db.db_models import Knowledgebase, Document +from web_server.db.services.common_service import CommonService +from web_server.utils import get_uuid, get_format_time +from web_server.db.db_utils import StatusEnum + + +class KnowledgebaseService(CommonService): + model = Knowledgebase + + +class DocumentService(CommonService): + model = Document diff --git a/web_server/db/services/llm_service.py b/web_server/db/services/llm_service.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6b575fea608b8c21f749d22535fe3ca1e57272 --- /dev/null +++ b/web_server/db/services/llm_service.py @@ -0,0 +1,35 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
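KnowledgebaseService.get_by_tenant_ids above returns the knowledgebases a user may see: those whose tenant_id matches the user plus team-permission ones from tenants the user has joined. A sketch of the listing call; the joined tenant ids would normally come from TenantService.get_joined_tenants_by_user_id in user_service.py below, and the ids here are illustrative:

from web_server.db.services.kb_service import KnowledgebaseService

joined_tenant_ids = ["tenant-0002", "tenant-0003"]
kbs = KnowledgebaseService.get_by_tenant_ids(
    joined_tenant_ids,
    "user-0001",               # compared against Knowledgebase.tenant_id for ownership
    page_number=1,
    items_per_page=30,
    orderby="create_time",
    desc=True,
)
for kb in kbs:                 # items are plain dicts because the query ends in .dicts()
    print(kb["name"], kb["doc_num"])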
+# +import peewee +from werkzeug.security import generate_password_hash, check_password_hash + +from web_server.db.db_models import DB, UserTenant +from web_server.db.db_models import LLMFactories, LLM, TenantLLM +from web_server.db.services.common_service import CommonService +from web_server.utils import get_uuid, get_format_time +from web_server.db.db_utils import StatusEnum + + +class LLMFactoriesService(CommonService): + model = LLMFactories + + +class LLMService(CommonService): + model = LLM + + +class TenantLLMService(CommonService): + model = TenantLLM diff --git a/web_server/db/services/user_service.py b/web_server/db/services/user_service.py new file mode 100644 index 0000000000000000000000000000000000000000..42e0b5c11ad60bb1486fa41cb2d72b81937c57e0 --- /dev/null +++ b/web_server/db/services/user_service.py @@ -0,0 +1,105 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import peewee +from werkzeug.security import generate_password_hash, check_password_hash + +from web_server.db import UserTenantRole +from web_server.db.db_models import DB, UserTenant +from web_server.db.db_models import User, Tenant +from web_server.db.services.common_service import CommonService +from web_server.utils import get_uuid, get_format_time +from web_server.db.db_utils import StatusEnum + + +class UserService(CommonService): + model = User + + @classmethod + @DB.connection_context() + def filter_by_id(cls, user_id): + try: + user = cls.model.select().where(cls.model.id == user_id).get() + return user + except peewee.DoesNotExist: + return None + + @classmethod + @DB.connection_context() + def query_user(cls, email, password): + user = cls.model.select().where((cls.model.email == email), + (cls.model.status == StatusEnum.VALID.value)).first() + if user and check_password_hash(str(user.password), password): + return user + else: + return None + + @classmethod + @DB.connection_context() + def save(cls, **kwargs): + if "id" not in kwargs: + kwargs["id"] = get_uuid() + if "password" in kwargs: + kwargs["password"] = generate_password_hash(str(kwargs["password"])) + obj = cls.model(**kwargs).save(force_insert=True) + return obj + + + @classmethod + @DB.connection_context() + def delete_user(cls, user_ids, update_user_dict): + with DB.atomic(): + cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute() + + @classmethod + @DB.connection_context() + def update_user(cls, user_id, user_dict): + date_time = get_format_time() + with DB.atomic(): + if user_dict: + user_dict["update_time"] = date_time + cls.model.update(user_dict).where(cls.model.id == user_id).execute() + + +class TenantService(CommonService): + model = Tenant + + @classmethod + @DB.connection_context() + def get_by_user_id(cls, user_id): + fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] + return list(cls.model.select(*fields)\ + .join(UserTenant, on=((cls.model.id 
== UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ + .where(cls.model.status == StatusEnum.VALID.value).dicts()) + + @classmethod + @DB.connection_context() + def get_joined_tenants_by_user_id(cls, user_id): + fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] + return list(cls.model.select(*fields)\ + .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\ + .where(cls.model.status == StatusEnum.VALID.value).dicts()) + + +class UserTenantService(CommonService): + model = UserTenant + + @classmethod + @DB.connection_context() + def save(cls, **kwargs): + if "id" not in kwargs: + kwargs["id"] = get_uuid() + obj = cls.model(**kwargs).save(force_insert=True) + return obj diff --git a/web_server/errors/__init__.py b/web_server/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..358eb012bb0bfd4cbf95f469810ea18e69dcc934 --- /dev/null +++ b/web_server/errors/__init__.py @@ -0,0 +1,10 @@ +from .general_error import * + + +class FateFlowError(Exception): + message = 'Unknown Fate Flow Error' + + def __init__(self, message=None, *args, **kwargs): + message = str(message) if message is not None else self.message + message = message.format(*args, **kwargs) + super().__init__(message) \ No newline at end of file diff --git a/web_server/errors/error_services.py b/web_server/errors/error_services.py new file mode 100644 index 0000000000000000000000000000000000000000..f391a91884ca6d511665f49226c6b4b4b4fd11db --- /dev/null +++ b/web_server/errors/error_services.py @@ -0,0 +1,13 @@ +from web_server.errors import FateFlowError + +__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured', + 'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError'] + + +class ServicesError(FateFlowError): + message = 'Unknown services error' + + +class ServiceNotSupported(ServicesError): + message = 'The service {service_name} is not supported' + diff --git a/web_server/errors/general_error.py b/web_server/errors/general_error.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fd3fb88b2ba58c85893d9905f1529e29391667 --- /dev/null +++ b/web_server/errors/general_error.py @@ -0,0 +1,21 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +class ParameterError(Exception): + pass + + +class PassError(Exception): + pass \ No newline at end of file diff --git a/web_server/flask_session/2029240f6d1128be89ddc32729463129 b/web_server/flask_session/2029240f6d1128be89ddc32729463129 new file mode 100644 index 0000000000000000000000000000000000000000..60b84f8bf0af235343c89653c31a85c904ebfc66 Binary files /dev/null and b/web_server/flask_session/2029240f6d1128be89ddc32729463129 differ diff --git a/web_server/hook/__init__.py b/web_server/hook/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c21c0718449bce394256b44796c923b3c67fc80 --- /dev/null +++ b/web_server/hook/__init__.py @@ -0,0 +1,57 @@ +import importlib + +from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, \ + SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters +from web_server.settings import HOOK_MODULE, stat_logger,RetCode + + +class HookManager: + SITE_SIGNATURE = [] + SITE_AUTHENTICATION = [] + CLIENT_AUTHENTICATION = [] + PERMISSION_CHECK = [] + + @staticmethod + def init(): + if HOOK_MODULE is not None: + for modules in HOOK_MODULE.values(): + for module in modules.split(";"): + try: + importlib.import_module(module) + except Exception as e: + stat_logger.exception(e) + + @staticmethod + def register_site_signature_hook(func): + HookManager.SITE_SIGNATURE.append(func) + + @staticmethod + def register_site_authentication_hook(func): + HookManager.SITE_AUTHENTICATION.append(func) + + @staticmethod + def register_client_authentication_hook(func): + HookManager.CLIENT_AUTHENTICATION.append(func) + + @staticmethod + def register_permission_check_hook(func): + HookManager.PERMISSION_CHECK.append(func) + + @staticmethod + def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn: + if HookManager.CLIENT_AUTHENTICATION: + return HookManager.CLIENT_AUTHENTICATION[0](parm) + return ClientAuthenticationReturn() + + @staticmethod + def site_signature(parm: SignatureParameters) -> SignatureReturn: + if HookManager.SITE_SIGNATURE: + return HookManager.SITE_SIGNATURE[0](parm) + return SignatureReturn() + + @staticmethod + def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn: + if HookManager.SITE_AUTHENTICATION: + return HookManager.SITE_AUTHENTICATION[0](parm) + return AuthenticationReturn() + diff --git a/web_server/hook/api/client_authentication.py b/web_server/hook/api/client_authentication.py new file mode 100644 index 0000000000000000000000000000000000000000..99e93892dac2d29045ab5b5343d99b246df274ef --- /dev/null +++ b/web_server/hook/api/client_authentication.py @@ -0,0 +1,29 @@ +import requests + +from web_server.db.service_registry import ServiceRegistry +from web_server.settings import RegistryServiceName +from web_server.hook import HookManager +from web_server.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn +from web_server.settings import HOOK_SERVER_NAME + + +@HookManager.register_client_authentication_hook +def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn: + service_list = ServiceRegistry.load_service( + server_name=HOOK_SERVER_NAME, + service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value + ) + if not service_list: + raise Exception(f"client authentication error: no found server" + f" {HOOK_SERVER_NAME} service client_authentication") + service = service_list[0] + response = 
getattr(requests, service.f_method.lower(), None)( + url=service.f_url, + json=parm.to_dict() + ) + if response.status_code != 200: + raise Exception( + f"client authentication error: request authentication url failed, status code {response.status_code}") + elif response.json().get("code") != 0: + return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg")) + return ClientAuthenticationReturn() \ No newline at end of file diff --git a/web_server/hook/api/permission.py b/web_server/hook/api/permission.py new file mode 100644 index 0000000000000000000000000000000000000000..318173d0e7a7b1dbd080073d9771cb88bb6a0b06 --- /dev/null +++ b/web_server/hook/api/permission.py @@ -0,0 +1,25 @@ +import requests + +from web_server.db.service_registry import ServiceRegistry +from web_server.settings import RegistryServiceName +from web_server.hook import HookManager +from web_server.hook.common.parameters import PermissionCheckParameters, PermissionReturn +from web_server.settings import HOOK_SERVER_NAME + + +@HookManager.register_permission_check_hook +def permission(parm: PermissionCheckParameters) -> PermissionReturn: + service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value) + if not service_list: + raise Exception(f"permission check error: no found server {HOOK_SERVER_NAME} service permission") + service = service_list[0] + response = getattr(requests, service.f_method.lower(), None)( + url=service.f_url, + json=parm.to_dict() + ) + if response.status_code != 200: + raise Exception( + f"permission check error: request permission url failed, status code {response.status_code}") + elif response.json().get("code") != 0: + return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg")) + return PermissionReturn() diff --git a/web_server/hook/api/site_authentication.py b/web_server/hook/api/site_authentication.py new file mode 100644 index 0000000000000000000000000000000000000000..bea3b77888e839b2e0e4de56712822219e181914 --- /dev/null +++ b/web_server/hook/api/site_authentication.py @@ -0,0 +1,49 @@ +import requests + +from web_server.db.service_registry import ServiceRegistry +from web_server.settings import RegistryServiceName +from web_server.hook import HookManager +from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\ + SignatureReturn +from web_server.settings import HOOK_SERVER_NAME, PARTY_ID + + +@HookManager.register_site_signature_hook +def signature(parm: SignatureParameters) -> SignatureReturn: + service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value) + if not service_list: + raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature") + service = service_list[0] + response = getattr(requests, service.f_method.lower(), None)( + url=service.f_url, + json=parm.to_dict() + ) + if response.status_code == 200: + if response.json().get("code") == 0: + return SignatureReturn(site_signature=response.json().get("data")) + else: + raise Exception(f"signature error: request signature url failed, result: {response.json()}") + else: + raise Exception(f"signature error: request signature url failed, status code {response.status_code}") + + +@HookManager.register_site_authentication_hook +def authentication(parm: AuthenticationParameters) -> AuthenticationReturn: + if not parm.src_party_id or 
str(parm.src_party_id) == "0": + parm.src_party_id = PARTY_ID + service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, + service_name=RegistryServiceName.SITE_AUTHENTICATION.value) + if not service_list: + raise Exception( + f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication") + service = service_list[0] + response = getattr(requests, service.f_method.lower(), None)( + url=service.f_url, + json=parm.to_dict() + ) + if response.status_code != 200: + raise Exception( + f"site authentication error: request site_authentication url failed, status code {response.status_code}") + elif response.json().get("code") != 0: + return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg")) + return AuthenticationReturn() \ No newline at end of file diff --git a/web_server/hook/common/parameters.py b/web_server/hook/common/parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..40ce4ef19336fb4e4177a3e739894239a6add5ad --- /dev/null +++ b/web_server/hook/common/parameters.py @@ -0,0 +1,56 @@ +from web_server.settings import RetCode + + +class ParametersBase: + def to_dict(self): + d = {} + for k, v in self.__dict__.items(): + d[k] = v + return d + + +class ClientAuthenticationParameters(ParametersBase): + def __init__(self, full_path, headers, form, data, json): + self.full_path = full_path + self.headers = headers + self.form = form + self.data = data + self.json = json + + +class ClientAuthenticationReturn(ParametersBase): + def __init__(self, code=RetCode.SUCCESS, message="success"): + self.code = code + self.message = message + + +class SignatureParameters(ParametersBase): + def __init__(self, party_id, body): + self.party_id = party_id + self.body = body + + +class SignatureReturn(ParametersBase): + def __init__(self, code=RetCode.SUCCESS, site_signature=None): + self.code = code + self.site_signature = site_signature + + +class AuthenticationParameters(ParametersBase): + def __init__(self, site_signature, body): + self.site_signature = site_signature + self.body = body + + +class AuthenticationReturn(ParametersBase): + def __init__(self, code=RetCode.SUCCESS, message="success"): + self.code = code + self.message = message + + +class PermissionReturn(ParametersBase): + def __init__(self, code=RetCode.SUCCESS, message="success"): + self.code = code + self.message = message + + diff --git a/web_server/ragflow_server.py b/web_server/ragflow_server.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd9181d68208312d02d9ee44ffa4837f8662bdb --- /dev/null +++ b/web_server/ragflow_server.py @@ -0,0 +1,80 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# init env. 
must be the first import + +import logging +import os +import signal +import sys +import traceback + +from werkzeug.serving import run_simple + +from web_server.apps import app +from web_server.db.runtime_config import RuntimeConfig +from web_server.hook import HookManager +from web_server.settings import ( + HOST, HTTP_PORT, access_logger, database_logger, stat_logger, +) +from web_server import utils + +from web_server.db.db_models import init_database_tables as init_web_db +from web_server.db.init_data import init_web_data +from web_server.versions import get_versions + +if __name__ == '__main__': + stat_logger.info( + f'project base: {utils.file_utils.get_project_base_directory()}' + ) + + # init db + init_web_db() + init_web_data() + # init runtime config + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--version', default=False, help="fate flow version", action='store_true') + parser.add_argument('--debug', default=False, help="debug mode", action='store_true') + args = parser.parse_args() + if args.version: + print(get_versions()) + sys.exit(0) + + RuntimeConfig.DEBUG = args.debug + if RuntimeConfig.DEBUG: + stat_logger.info("run on debug mode") + + RuntimeConfig.init_env() + RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT) + + HookManager.init() + + peewee_logger = logging.getLogger('peewee') + peewee_logger.propagate = False + # fate_arch.common.log.ROpenHandler + peewee_logger.addHandler(database_logger.handlers[0]) + peewee_logger.setLevel(database_logger.level) + + # start http server + try: + stat_logger.info("FATE Flow http server start...") + werkzeug_logger = logging.getLogger("werkzeug") + for h in access_logger.handlers: + werkzeug_logger.addHandler(h) + run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG) + except Exception: + traceback.print_exc() + os.kill(os.getpid(), signal.SIGKILL) \ No newline at end of file diff --git a/web_server/settings.py b/web_server/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..a93efaa93329be4ec765d22d43f747046a55bd75 --- /dev/null +++ b/web_server/settings.py @@ -0,0 +1,156 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os + +from enum import IntEnum, Enum + +from web_server.utils import get_base_config,decrypt_database_config +from web_server.utils.file_utils import get_project_base_directory +from web_server.utils.log_utils import LoggerFactory, getLogger + + +# Server +API_VERSION = "v1" +FATE_FLOW_SERVICE_NAME = "ragflow" +SERVER_MODULE = "rag_flow_server.py" +TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp") +FATE_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf") + +SUBPROCESS_STD_LOG_NAME = "std.log" + +ERROR_REPORT = True +ERROR_REPORT_WITH_PATH = False + +MAX_TIMESTAMP_INTERVAL = 60 +SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000 + +REQUEST_TRY_TIMES = 3 +REQUEST_WAIT_SEC = 2 +REQUEST_MAX_WAIT_SEC = 300 + +USE_REGISTRY = get_base_config("use_registry") + +LLM = get_base_config("llm", {}) +CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo") +EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002") +ASR_MDL = LLM.get("asr_model", "whisper-1") +PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report") +IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview") + +# distribution +DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) +FATE_FLOW_UPDATE_CHECK = False + +HOST = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") +HTTP_PORT = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("http_port") + +SECRET_KEY = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow") +TOKEN_EXPIRE_IN = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600) + +NGINX_HOST = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST +NGINX_HTTP_PORT = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT + +RANDOM_INSTANCE_ID = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("random_instance_id", False) + +PROXY = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("proxy") +PROXY_PROTOCOL = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("protocol") + +DATABASE = decrypt_database_config() + +# Logger +LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "web_server")) +# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} +LoggerFactory.LEVEL = 10 + +stat_logger = getLogger("stat") +access_logger = getLogger("access") +database_logger = getLogger("database") + +# Switch +# upload +UPLOAD_DATA_FROM_CLIENT = True + +# authentication +AUTHENTICATION_CONF = get_base_config("authentication", {}) + +# client +CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False) +HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") +GITHUB_OAUTH = get_base_config("oauth", {}).get("github") +WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat") + +# site +SITE_AUTHENTICATION = AUTHENTICATION_CONF.get("site", {}).get("switch", False) + +# permission +PERMISSION_CONF = get_base_config("permission", {}) +PERMISSION_SWITCH = PERMISSION_CONF.get("switch") +COMPONENT_PERMISSION = PERMISSION_CONF.get("component") +DATASET_PERMISSION = PERMISSION_CONF.get("dataset") + +HOOK_MODULE = get_base_config("hook_module") +HOOK_SERVER_NAME = get_base_config("hook_server_name") + +ENABLE_MODEL_STORE = get_base_config('enable_model_store', False) +# authentication +USE_AUTHENTICATION = False +USE_DATA_AUTHENTICATION = False +AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True 
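+# Illustrative note (hypothetical module paths): HOOK_MODULE above is read from
+# service_conf.yaml via get_base_config() and is expected to be a mapping, e.g.
+#   hook_module:
+#     client_authentication: web_server.hook.api.client_authentication;web_server.hook.api.permission
+# HookManager.init() imports every ";"-separated module listed there, and the
+# @HookManager.register_*_hook decorators inside those modules populate the
+# CLIENT_AUTHENTICATION / SITE_SIGNATURE / SITE_AUTHENTICATION / PERMISSION_CHECK
+# lists, so the dispatch helpers (client_authentication, site_signature,
+# site_authentication) call the first registered hook and otherwise return the
+# default "success" result objects.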
+USE_DEFAULT_TIMEOUT = False +AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s +PRIVILEGE_COMMAND_WHITELIST = [] +CHECK_NODES_IDENTITY = False + +class CustomEnum(Enum): + @classmethod + def valid(cls, value): + try: + cls(value) + return True + except: + return False + + @classmethod + def values(cls): + return [member.value for member in cls.__members__.values()] + + @classmethod + def names(cls): + return [member.name for member in cls.__members__.values()] + + +class PythonDependenceName(CustomEnum): + Fate_Source_Code = "python" + Python_Env = "miniconda" + + +class ModelStorage(CustomEnum): + REDIS = "redis" + MYSQL = "mysql" + + +class RetCode(IntEnum, CustomEnum): + SUCCESS = 0 + NOT_EFFECTIVE = 10 + EXCEPTION_ERROR = 100 + ARGUMENT_ERROR = 101 + DATA_ERROR = 102 + OPERATING_ERROR = 103 + CONNECTION_ERROR = 105 + RUNNING = 106 + PERMISSION_ERROR = 108 + AUTHENTICATION_ERROR = 109 + SERVER_ERROR = 500 diff --git a/web_server/utils/__init__.py b/web_server/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57c11ba1cbc7a7ac84460562d8940ad14fff6184 --- /dev/null +++ b/web_server/utils/__init__.py @@ -0,0 +1,321 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import base64 +from datetime import datetime +import io +import json +import os +import pickle +import socket +import time +import uuid +import requests +from enum import Enum, IntEnum +import importlib +from Cryptodome.PublicKey import RSA +from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 + +from filelock import FileLock + +from . 
import file_utils + +SERVICE_CONF = "service_conf.yaml" + +def conf_realpath(conf_name): + conf_path = f"conf/{conf_name}" + return os.path.join(file_utils.get_project_base_directory(), conf_path) + +def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict: + local_config = {} + local_path = conf_realpath(f'local.{conf_name}') + if default is None: + default = os.environ.get(key.upper()) + + if os.path.exists(local_path): + local_config = file_utils.load_yaml_conf(local_path) + if not isinstance(local_config, dict): + raise ValueError(f'Invalid config file: "{local_path}".') + + if key is not None and key in local_config: + return local_config[key] + + config_path = conf_realpath(conf_name) + config = file_utils.load_yaml_conf(config_path) + + if not isinstance(config, dict): + raise ValueError(f'Invalid config file: "{config_path}".') + + config.update(local_config) + return config.get(key, default) if key is not None else config + + +use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False) + + +class CoordinationCommunicationProtocol(object): + HTTP = "http" + GRPC = "grpc" + + +class BaseType: + def to_dict(self): + return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()]) + + def to_dict_with_type(self): + def _dict(obj): + module = None + if issubclass(obj.__class__, BaseType): + data = {} + for attr, v in obj.__dict__.items(): + k = attr.lstrip("_") + data[k] = _dict(v) + module = obj.__module__ + elif isinstance(obj, (list, tuple)): + data = [] + for i, vv in enumerate(obj): + data.append(_dict(vv)) + elif isinstance(obj, dict): + data = {} + for _k, vv in obj.items(): + data[_k] = _dict(vv) + else: + data = obj + return {"type": obj.__class__.__name__, "data": data, "module": module} + return _dict(self) + + +class CustomJSONEncoder(json.JSONEncoder): + def __init__(self, **kwargs): + self._with_type = kwargs.pop("with_type", False) + super().__init__(**kwargs) + + def default(self, obj): + if isinstance(obj, datetime.datetime): + return obj.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(obj, datetime.date): + return obj.strftime('%Y-%m-%d') + elif isinstance(obj, datetime.timedelta): + return str(obj) + elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum): + return obj.value + elif isinstance(obj, set): + return list(obj) + elif issubclass(type(obj), BaseType): + if not self._with_type: + return obj.to_dict() + else: + return obj.to_dict_with_type() + elif isinstance(obj, type): + return obj.__name__ + else: + return json.JSONEncoder.default(self, obj) + + +def rag_uuid(): + return uuid.uuid1().hex + + +def string_to_bytes(string): + return string if isinstance(string, bytes) else string.encode(encoding="utf-8") + + +def bytes_to_string(byte): + return byte.decode(encoding="utf-8") + + +def json_dumps(src, byte=False, indent=None, with_type=False): + dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type) + if byte: + dest = string_to_bytes(dest) + return dest + + +def json_loads(src, object_hook=None, object_pairs_hook=None): + if isinstance(src, bytes): + src = bytes_to_string(src) + return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook) + + +def current_timestamp(): + return int(time.time() * 1000) + + +def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"): + if not timestamp: + timestamp = time.time() + timestamp = int(timestamp) / 1000 + time_array = time.localtime(timestamp) + str_date = time.strftime(format_string, time_array) + return 
str_date + + +def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"): + time_array = time.strptime(time_str, format_string) + time_stamp = int(time.mktime(time_array) * 1000) + return time_stamp + + +def serialize_b64(src, to_str=False): + dest = base64.b64encode(pickle.dumps(src)) + if not to_str: + return dest + else: + return bytes_to_string(dest) + + +def deserialize_b64(src): + src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src) + if use_deserialize_safe_module: + return restricted_loads(src) + return pickle.loads(src) + + +safe_module = { + 'numpy', + 'fate_flow' +} + + +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + import importlib + if module.split('.')[0] in safe_module: + _module = importlib.import_module(module) + return getattr(_module, name) + # Forbid everything else. + raise pickle.UnpicklingError("global '%s.%s' is forbidden" % + (module, name)) + + +def restricted_loads(src): + """Helper function analogous to pickle.loads().""" + return RestrictedUnpickler(io.BytesIO(src)).load() + + +def get_lan_ip(): + if os.name != "nt": + import fcntl + import struct + + def get_interface_ip(ifname): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + return socket.inet_ntoa( + fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24]) + + ip = socket.gethostbyname(socket.getfqdn()) + if ip.startswith("127.") and os.name != "nt": + interfaces = [ + "bond1", + "eth0", + "eth1", + "eth2", + "wlan0", + "wlan1", + "wifi0", + "ath0", + "ath1", + "ppp0", + ] + for ifname in interfaces: + try: + ip = get_interface_ip(ifname) + break + except IOError as e: + pass + return ip or '' + +def from_dict_hook(in_dict: dict): + if "type" in in_dict and "data" in in_dict: + if in_dict["module"] is None: + return in_dict["data"] + else: + return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"]) + else: + return in_dict + + +def decrypt_database_password(password): + encrypt_password = get_base_config("encrypt_password", False) + encrypt_module = get_base_config("encrypt_module", False) + private_key = get_base_config("private_key", None) + + if not password or not encrypt_password: + return password + + if not private_key: + raise ValueError("No private key") + + module_fun = encrypt_module.split("#") + pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1]) + + return pwdecrypt_fun(private_key, password) + + +def decrypt_database_config(database=None, passwd_key="passwd", name="database"): + if not database: + database = get_base_config(name, {}) + + database[passwd_key] = decrypt_database_password(database[passwd_key]) + return database + + +def update_config(key, value, conf_name=SERVICE_CONF): + conf_path = conf_realpath(conf_name=conf_name) + if not os.path.isabs(conf_path): + conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path) + + with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): + config = file_utils.load_yaml_conf(conf_path=conf_path) or {} + config[key] = value + file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config) + + +def get_uuid(): + return uuid.uuid1().hex + + +def datetime_format(date_time: datetime) -> datetime: + return datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second) + + +def get_format_time() -> datetime: + return datetime_format(datetime.now()) + + +def str2date(date_time: str): + return 
datetime.strptime(date_time, '%Y-%m-%d') + + +def elapsed2time(elapsed): + seconds = elapsed / 1000 + minuter, second = divmod(seconds, 60) + hour, minuter = divmod(minuter, 60) + return '%02d:%02d:%02d' % (hour, minuter, second) + + +def decrypt(line): + file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem") + rsa_key = RSA.importKey(open(file_path).read(), "Welcome") + cipher = Cipher_pkcs1_v1_5.new(rsa_key) + return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8') + + +def download_img(url): + if not url: return "" + response = requests.get(url) + return "data:" + \ + response.headers.get('Content-Type', 'image/jpg') + ";" + \ + "base64," + base64.b64encode(response.content).decode("utf-8") diff --git a/web_server/utils/api_utils.py b/web_server/utils/api_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2933a0621cbc8ce8f76c99510ac58d26b5210a4d --- /dev/null +++ b/web_server/utils/api_utils.py @@ -0,0 +1,212 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import random +import time +from functools import wraps +from io import BytesIO +from flask import ( + Response, jsonify, send_file,make_response, + request as flask_request, +) +from werkzeug.http import HTTP_STATUS_CODES + +from web_server.utils import json_dumps +from web_server.versions import get_fate_version +from web_server.settings import RetCode +from web_server.settings import ( + REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, + stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY +) +import requests +import functools +from web_server.utils import CustomJSONEncoder +from uuid import uuid1 +from base64 import b64encode +from hmac import HMAC +from urllib.parse import quote, urlencode + + +requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) + + +def request(**kwargs): + sess = requests.Session() + stream = kwargs.pop('stream', sess.stream) + timeout = kwargs.pop('timeout', None) + kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()} + prepped = requests.Request(**kwargs).prepare() + + if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY: + timestamp = str(round(time() * 1000)) + nonce = str(uuid1()) + signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([ + timestamp.encode('ascii'), + nonce.encode('ascii'), + HTTP_APP_KEY.encode('ascii'), + prepped.path_url.encode('ascii'), + prepped.body if kwargs.get('json') else b'', + urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii') + if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'', + ]), 'sha1').digest()).decode('ascii') + + prepped.headers.update({ + 'TIMESTAMP': timestamp, + 'NONCE': nonce, + 'APP-KEY': HTTP_APP_KEY, + 'SIGNATURE': signature, + }) + + return sess.send(prepped, stream=stream, timeout=timeout) + + +fate_version = get_fate_version() or '' + + +def 
get_exponential_backoff_interval(retries, full_jitter=False): + """Calculate the exponential backoff wait time.""" + # Will be zero if factor equals 0 + countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries)) + # Full jitter according to + # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + if full_jitter: + countdown = random.randrange(countdown + 1) + # Adjust according to maximum wait time and account for negative values. + return max(0, countdown) + + +def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None): + import re + result_dict = { + "retcode": retcode, + "retmsg":retmsg, + # "retmsg": re.sub(r"fate", "seceum", retmsg, flags=re.IGNORECASE), + "data": data, + "jobId": job_id, + "meta": meta, + } + + response = {} + for key, value in result_dict.items(): + if value is None and key != "retcode": + continue + else: + response[key] = value + return jsonify(response) + +def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'): + import re + result_dict = {"retcode": retcode, "retmsg": re.sub(r"fate", "seceum", retmsg, flags=re.IGNORECASE)} + response = {} + for key, value in result_dict.items(): + if value is None and key != "retcode": + continue + else: + response[key] = value + return jsonify(response) + +def server_error_response(e): + stat_logger.exception(e) + try: + if e.code==401: + return get_json_result(retcode=401, retmsg=repr(e)) + except: + pass + if len(e.args) > 1: + return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) + return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e)) + + +def error_response(response_code, retmsg=None): + if retmsg is None: + retmsg = HTTP_STATUS_CODES.get(response_code, 'Unknown Error') + + return Response(json.dumps({ + 'retmsg': retmsg, + 'retcode': response_code, + }), status=response_code, mimetype='application/json') + + +def validate_request(*args, **kwargs): + def wrapper(func): + @wraps(func) + def decorated_function(*_args, **_kwargs): + input_arguments = flask_request.json or flask_request.form.to_dict() + no_arguments = [] + error_arguments = [] + for arg in args: + if arg not in input_arguments: + no_arguments.append(arg) + for k, v in kwargs.items(): + config_value = input_arguments.get(k, None) + if config_value is None: + no_arguments.append(k) + elif isinstance(v, (tuple, list)): + if config_value not in v: + error_arguments.append((k, set(v))) + elif config_value != v: + error_arguments.append((k, v)) + if no_arguments or error_arguments: + error_string = "" + if no_arguments: + error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) + if error_arguments: + error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) + return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) + return func(*_args, **_kwargs) + return decorated_function + return wrapper + + +def is_localhost(ip): + return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'} + + +def send_file_in_mem(data, filename): + if not isinstance(data, (str, bytes)): + data = json_dumps(data) + if isinstance(data, str): + data = data.encode('utf-8') + + f = BytesIO() + f.write(data) + f.seek(0) + + return send_file(f, as_attachment=True, attachment_filename=filename) + + +def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None): + response = {"retcode": retcode, "retmsg": retmsg, "data": 
data} + return jsonify(response) + + +def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None): + result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data} + response_dict = {} + for key, value in result_dict.items(): + if value is None and key != "retcode": + continue + else: + response_dict[key] = value + response = make_response(jsonify(response_dict)) + if auth: + response.headers["Authorization"] = auth + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Method"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Expose-Headers"] = "Authorization" + return response \ No newline at end of file diff --git a/web_server/utils/file_utils.py b/web_server/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54b1514ecf85b70b5ebdf0155e9c68ffb92f5c6f --- /dev/null +++ b/web_server/utils/file_utils.py @@ -0,0 +1,153 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import os +import re + +from cachetools import LRUCache, cached +from ruamel.yaml import YAML + +from web_server.db import FileType + +PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") +FATE_BASE = os.getenv("RAG_BASE") + +def get_project_base_directory(*args): + global PROJECT_BASE + if PROJECT_BASE is None: + PROJECT_BASE = os.path.abspath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + os.pardir, + os.pardir, + ) + ) + + if args: + return os.path.join(PROJECT_BASE, *args) + return PROJECT_BASE + + +def get_fate_directory(*args): + global FATE_BASE + if FATE_BASE is None: + FATE_BASE = os.path.abspath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + os.pardir, + os.pardir, + os.pardir, + ) + ) + if args: + return os.path.join(FATE_BASE, *args) + return FATE_BASE + + +def get_fate_python_directory(*args): + return get_fate_directory("python", *args) + + + +@cached(cache=LRUCache(maxsize=10)) +def load_json_conf(conf_path): + if os.path.isabs(conf_path): + json_conf_path = conf_path + else: + json_conf_path = os.path.join(get_project_base_directory(), conf_path) + try: + with open(json_conf_path) as f: + return json.load(f) + except BaseException: + raise EnvironmentError( + "loading json file config from '{}' failed!".format(json_conf_path) + ) + + +def dump_json_conf(config_data, conf_path): + if os.path.isabs(conf_path): + json_conf_path = conf_path + else: + json_conf_path = os.path.join(get_project_base_directory(), conf_path) + try: + with open(json_conf_path, "w") as f: + json.dump(config_data, f, indent=4) + except BaseException: + raise EnvironmentError( + "loading json file config from '{}' failed!".format(json_conf_path) + ) + + +def load_json_conf_real_time(conf_path): + if os.path.isabs(conf_path): + json_conf_path = conf_path + else: + json_conf_path = 
os.path.join(get_project_base_directory(), conf_path) + try: + with open(json_conf_path) as f: + return json.load(f) + except BaseException: + raise EnvironmentError( + "loading json file config from '{}' failed!".format(json_conf_path) + ) + + +def load_yaml_conf(conf_path): + if not os.path.isabs(conf_path): + conf_path = os.path.join(get_project_base_directory(), conf_path) + try: + with open(conf_path) as f: + yaml = YAML(typ='safe', pure=True) + return yaml.load(f) + except Exception as e: + raise EnvironmentError( + "loading yaml file config from {} failed:".format(conf_path), e + ) + + +def rewrite_yaml_conf(conf_path, config): + if not os.path.isabs(conf_path): + conf_path = os.path.join(get_project_base_directory(), conf_path) + try: + with open(conf_path, "w") as f: + yaml = YAML(typ="safe") + yaml.dump(config, f) + except Exception as e: + raise EnvironmentError( + "rewrite yaml file config {} failed:".format(conf_path), e + ) + + +def rewrite_json_file(filepath, json_data): + with open(filepath, "w") as f: + json.dump(json_data, f, indent=4, separators=(",", ": ")) + f.close() + + +def filename_type(filename): + filename = filename.lower() + if re.match(r".*\.pdf$", filename): + return FileType.PDF.value + + if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename): + return FileType.DOC.value + + if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): + return FileType.AURAL.value + + if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): + return FileType.VISUAL \ No newline at end of file diff --git a/web_server/utils/log_utils.py b/web_server/utils/log_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5efe3c81712e96b7ebc2b4aa2567ae8cb7d8485d --- /dev/null +++ b/web_server/utils/log_utils.py @@ -0,0 +1,299 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import typing +import traceback +import logging +import inspect +from logging.handlers import TimedRotatingFileHandler +from threading import RLock + +from web_server.utils import file_utils + +class LoggerFactory(object): + TYPE = "FILE" + LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s" + LEVEL = logging.DEBUG + logger_dict = {} + global_handler_dict = {} + + LOG_DIR = None + PARENT_LOG_DIR = None + log_share = True + + append_to_parent_log = None + + lock = RLock() + # CRITICAL = 50 + # FATAL = CRITICAL + # ERROR = 40 + # WARNING = 30 + # WARN = WARNING + # INFO = 20 + # DEBUG = 10 + # NOTSET = 0 + levels = (10, 20, 30, 40) + schedule_logger_dict = {} + + @staticmethod + def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False): + if parent_log_dir: + LoggerFactory.PARENT_LOG_DIR = parent_log_dir + if append_to_parent_log: + LoggerFactory.append_to_parent_log = append_to_parent_log + with LoggerFactory.lock: + if not directory: + directory = file_utils.get_project_base_directory("logs") + if not LoggerFactory.LOG_DIR or force: + LoggerFactory.LOG_DIR = directory + if LoggerFactory.log_share: + oldmask = os.umask(000) + os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True) + os.umask(oldmask) + else: + os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True) + for loggerName, ghandler in LoggerFactory.global_handler_dict.items(): + for className, (logger, handler) in LoggerFactory.logger_dict.items(): + logger.removeHandler(ghandler) + ghandler.close() + LoggerFactory.global_handler_dict = {} + for className, (logger, handler) in LoggerFactory.logger_dict.items(): + logger.removeHandler(handler) + _handler = None + if handler: + handler.close() + if className != "default": + _handler = LoggerFactory.get_handler(className) + logger.addHandler(_handler) + LoggerFactory.assemble_global_handler(logger) + LoggerFactory.logger_dict[className] = logger, _handler + + @staticmethod + def new_logger(name): + logger = logging.getLogger(name) + logger.propagate = False + logger.setLevel(LoggerFactory.LEVEL) + return logger + + @staticmethod + def get_logger(class_name=None): + with LoggerFactory.lock: + if class_name in LoggerFactory.logger_dict.keys(): + logger, handler = LoggerFactory.logger_dict[class_name] + if not logger: + logger, handler = LoggerFactory.init_logger(class_name) + else: + logger, handler = LoggerFactory.init_logger(class_name) + return logger + + @staticmethod + def get_global_handler(logger_name, level=None, log_dir=None): + if not LoggerFactory.LOG_DIR: + return logging.StreamHandler() + if log_dir: + logger_name_key = logger_name + "_" + log_dir + else: + logger_name_key = logger_name + "_" + LoggerFactory.LOG_DIR + # if loggerName not in LoggerFactory.globalHandlerDict: + if logger_name_key not in LoggerFactory.global_handler_dict: + with LoggerFactory.lock: + if logger_name_key not in LoggerFactory.global_handler_dict: + handler = LoggerFactory.get_handler(logger_name, level, log_dir) + LoggerFactory.global_handler_dict[logger_name_key] = handler + return LoggerFactory.global_handler_dict[logger_name_key] + + @staticmethod + def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None): + if not log_type: + if not LoggerFactory.LOG_DIR or not class_name: + return logging.StreamHandler() + # return Diy_StreamHandler() + + if not log_dir: + log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name)) + else: + 
log_file = os.path.join(log_dir, "{}.log".format(class_name)) + else: + log_file = os.path.join(log_dir, "fate_flow_{}.log".format( + log_type) if level == LoggerFactory.LEVEL else 'fate_flow_{}_error.log'.format(log_type)) + job_id = job_id or os.getenv("FATE_JOB_ID") + if job_id: + formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", job_id)) + else: + formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", "Server")) + os.makedirs(os.path.dirname(log_file), exist_ok=True) + if LoggerFactory.log_share: + handler = ROpenHandler(log_file, + when='D', + interval=1, + backupCount=14, + delay=True) + else: + handler = TimedRotatingFileHandler(log_file, + when='D', + interval=1, + backupCount=14, + delay=True) + if level: + handler.level = level + + handler.setFormatter(formatter) + return handler + + @staticmethod + def init_logger(class_name): + with LoggerFactory.lock: + logger = LoggerFactory.new_logger(class_name) + handler = None + if class_name: + handler = LoggerFactory.get_handler(class_name) + logger.addHandler(handler) + LoggerFactory.logger_dict[class_name] = logger, handler + + else: + LoggerFactory.logger_dict["default"] = logger, handler + + LoggerFactory.assemble_global_handler(logger) + return logger, handler + + @staticmethod + def assemble_global_handler(logger): + if LoggerFactory.LOG_DIR: + for level in LoggerFactory.levels: + if level >= LoggerFactory.LEVEL: + level_logger_name = logging._levelToName[level] + logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level)) + if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR: + for level in LoggerFactory.levels: + if level >= LoggerFactory.LEVEL: + level_logger_name = logging._levelToName[level] + logger.addHandler( + LoggerFactory.get_global_handler(level_logger_name, level, LoggerFactory.PARENT_LOG_DIR)) + + +def setDirectory(directory=None): + LoggerFactory.set_directory(directory) + + +def setLevel(level): + LoggerFactory.LEVEL = level + + +def getLogger(className=None, useLevelFile=False): + if className is None: + frame = inspect.stack()[1] + module = inspect.getmodule(frame[0]) + className = 'stat' + return LoggerFactory.get_logger(className) + + +def exception_to_trace_string(ex): + return "".join(traceback.TracebackException.from_exception(ex).format()) + + +class ROpenHandler(TimedRotatingFileHandler): + def _open(self): + prevumask = os.umask(000) + rtv = TimedRotatingFileHandler._open(self) + os.umask(prevumask) + return rtv + + +def sql_logger(job_id='', log_type='sql'): + key = job_id + log_type + if key in LoggerFactory.schedule_logger_dict.keys(): + return LoggerFactory.schedule_logger_dict[key] + return get_job_logger(job_id=job_id, log_type=log_type) + + +def ready_log(msg, job=None, task=None, role=None, party_id=None, detail=None): + prefix, suffix = base_msg(job, task, role, party_id, detail) + return f"{prefix}{msg} ready{suffix}" + + +def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None): + prefix, suffix = base_msg(job, task, role, party_id, detail) + return f"{prefix}start to {msg}{suffix}" + + +def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None): + prefix, suffix = base_msg(job, task, role, party_id, detail) + return f"{prefix}{msg} successfully{suffix}" + + +def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None): + prefix, suffix = base_msg(job, task, role, party_id, detail) + return f"{prefix}{msg} is not effective{suffix}" + + +def 
failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None): + prefix, suffix = base_msg(job, task, role, party_id, detail) + return f"{prefix}failed to {msg}{suffix}" + + +def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None): + if detail: + detail_msg = f" detail: \n{detail}" + else: + detail_msg = "" + if task is not None: + return f"task {task.f_task_id} {task.f_task_version} ", f" on {task.f_role} {task.f_party_id}{detail_msg}" + elif job is not None: + return "", f" on {job.f_role} {job.f_party_id}{detail_msg}" + elif role and party_id: + return "", f" on {role} {party_id}{detail_msg}" + else: + return "", f"{detail_msg}" + + +def exception_to_trace_string(ex): + return "".join(traceback.TracebackException.from_exception(ex).format()) + + +def get_logger_base_dir(): + job_log_dir = file_utils.get_fate_flow_directory('logs') + return job_log_dir + + +def get_job_logger(job_id, log_type): + fate_flow_log_dir = file_utils.get_fate_flow_directory('logs', 'fate_flow') + job_log_dir = file_utils.get_fate_flow_directory('logs', job_id) + if not job_id: + log_dirs = [fate_flow_log_dir] + else: + if log_type == 'audit': + log_dirs = [job_log_dir, fate_flow_log_dir] + else: + log_dirs = [job_log_dir] + if LoggerFactory.log_share: + oldmask = os.umask(000) + os.makedirs(job_log_dir, exist_ok=True) + os.makedirs(fate_flow_log_dir, exist_ok=True) + os.umask(oldmask) + else: + os.makedirs(job_log_dir, exist_ok=True) + os.makedirs(fate_flow_log_dir, exist_ok=True) + logger = LoggerFactory.new_logger(f"{job_id}_{log_type}") + for job_log_dir in log_dirs: + handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL, + log_dir=job_log_dir, log_type=log_type, job_id=job_id) + error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id) + logger.addHandler(handler) + logger.addHandler(error_handler) + with LoggerFactory.lock: + LoggerFactory.schedule_logger_dict[job_id + log_type] = logger + return logger + diff --git a/web_server/utils/t_crypt.py b/web_server/utils/t_crypt.py new file mode 100644 index 0000000000000000000000000000000000000000..1d007f49c35a589d6e39bccaf011f059949e4337 --- /dev/null +++ b/web_server/utils/t_crypt.py @@ -0,0 +1,18 @@ +import base64, os, sys +from Cryptodome.PublicKey import RSA +from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 +from web_server.utils import decrypt, file_utils + +def crypt(line): + file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") + rsa_key = RSA.importKey(open(file_path).read()) + cipher = Cipher_pkcs1_v1_5.new(rsa_key) + return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8") + + + +if __name__ == "__main__": + pswd = crypt(sys.argv[1]) + print(pswd) + print(decrypt(pswd)) + diff --git a/web_server/versions.py b/web_server/versions.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e3cc135980d58fe394c0d38b29caf832cba708 --- /dev/null +++ b/web_server/versions.py @@ -0,0 +1,30 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +import dotenv +import typing + +from web_server.utils.file_utils import get_project_base_directory + + +def get_versions() -> typing.Mapping[str, typing.Any]: + return dotenv.dotenv_values( + dotenv_path=os.path.join(get_project_base_directory(), "rag.env") + ) + +def get_fate_version() -> typing.Optional[str]: + return get_versions().get("RAG") \ No newline at end of file
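For orientation, a minimal usage sketch of the services introduced in this patch. It assumes the web_server package is importable, the database tables have been initialized (init_database_tables) and conf/service_conf.yaml points at a reachable database; the e-mail, password and nickname values are illustrative, and the User model may require additional columns not shown here.

    from web_server.db.services.user_service import UserService, TenantService

    # UserService.save() fills in a UUID id and stores a werkzeug password hash.
    UserService.save(email="admin@example.com", password="secret", nickname="admin")

    # query_user() returns the user only if the stored hash matches and the
    # account status is StatusEnum.VALID.
    user = UserService.query_user("admin@example.com", "secret")
    if user:
        # get_by_user_id() joins Tenant with UserTenant and yields dicts with
        # tenant_id, name, llm_id, embd_id, asr_id, img2txt_id and role.
        for t in TenantService.get_by_user_id(user.id):
            print(t["tenant_id"], t["role"])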