Spaces:
Sleeping
Sleeping
| import os | |
| import sqlite3 | |
| import secrets | |
| import hashlib | |
| import spaces | |
| import time | |
| from argon2 import PasswordHasher | |
| from cryptography.fernet import Fernet | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| import numpy as np | |
# Initialize global variables
TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = os.getenv("SECRET_M")  # model id loaded by get_embedding
ADMIN_USERNAME = os.getenv("ADMIN_USERNAME")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
# Populated lazily by get_embedding on first use.
tokenizer = None
model = None
# Initialize Argon2 hasher and Fernet cipher
ph = PasswordHasher()
# A key generated at startup lives only as long as this process: anything
# encrypted with it is unreadable after a restart. Set FERNET_KEY to a
# persistent key (the output of Fernet.generate_key()) to keep ciphertext
# readable across restarts; when it is unset the previous ephemeral-key
# behaviour is kept.
_env_fernet_key = os.getenv("FERNET_KEY")
cipher_key = _env_fernet_key.encode() if _env_fernet_key else Fernet.generate_key()
cipher = Fernet(cipher_key)
def get_db_connection(db_path='database/grimvault.db'):
    """Open a SQLite connection whose rows can be indexed by column name.

    Args:
        db_path: Database file to open. Defaults to ``database/grimvault.db``,
            relative to the current working directory. The parent directory
            is created if it is missing, because ``sqlite3.connect`` cannot
            create intermediate directories itself. ``":memory:"`` also works.

    Returns:
        A ``sqlite3.Connection`` whose ``row_factory`` is ``sqlite3.Row``.
        The caller must close it.
    """
    directory = os.path.dirname(db_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn
def create_tables(conn):
    """Ensure the ``users`` and ``files`` tables exist, then commit.

    Safe to call repeatedly: both statements use ``IF NOT EXISTS``.
    """
    schema = (
        '''CREATE TABLE IF NOT EXISTS users
                 (username TEXT PRIMARY KEY, password_hash TEXT, embedding_hash TEXT,
                  salt TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''',
        '''CREATE TABLE IF NOT EXISTS files
                 (id INTEGER PRIMARY KEY, username TEXT, filename TEXT,
                  content BLOB, size INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''',
    )
    cursor = conn.cursor()
    for ddl in schema:
        cursor.execute(ddl)
    conn.commit()
def get_embedding(text):
    """Embed *text* with the MODEL_NAME model by mean-pooling its last hidden state.

    The tokenizer and model are loaded on the first call and cached in the
    module-level ``tokenizer`` / ``model`` globals. The model is loaded with
    ``torch_dtype=torch.float16``. Input is truncated to 512 tokens. The pooled
    tensor is squeezed and returned as a NumPy array.
    """
    global tokenizer, model
    if tokenizer is None or model is None:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
        if tokenizer.pad_token is None:
            # Some tokenizers ship without a pad token; reuse EOS so that
            # padding=True below works.
            tokenizer.pad_token = tokenizer.eos_token
            model.resize_token_embeddings(len(tokenizer))
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    pooled = hidden_states.mean(dim=1)  # average over the token axis
    return pooled.squeeze().numpy()
def hash_embedding(embedding, salt):
    """Return the hex SHA-256 digest of *embedding* with *salt* appended.

    The salt bytes are reinterpreted as float32 values (``np.frombuffer``
    requires the length to be a multiple of 4; create_user passes 16 bytes).
    They are then concatenated after the embedding under NumPy type promotion,
    and the raw bytes of the result are hashed.

    NOTE(review): SHA-256 is a fast hash, so unlike the Argon2 password hash
    this gives no brute-force resistance. Presumably this is acceptable
    because the embedding model is secret; confirm.
    """
    salt_as_floats = np.frombuffer(salt, dtype=np.float32)
    combined = np.concatenate((embedding, salt_as_floats))
    digest = hashlib.sha256(combined.tobytes())
    return digest.hexdigest()
def create_user(username, password):
    """Register *username* with both an Argon2 hash and an embedding hash.

    A random 16-byte salt is generated. Its hex form is appended to the
    password before Argon2 hashing, and the raw salt is mixed into the
    embedding hash. The salt is stored in the ``users`` row next to both
    hashes.

    Returns:
        "Username already exists." if the name is taken (including when a
        concurrent registration wins the race), else "User created successfully."
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        # Cheap pre-check so duplicate names never pay for the model embedding.
        c.execute("SELECT 1 FROM users WHERE username = ?", (username,))
        if c.fetchone():
            return "Username already exists."
        salt = secrets.token_bytes(16)
        password_hash = ph.hash(password + salt.hex())
        embedding = get_embedding(password)
        embedding_hash = hash_embedding(embedding, salt)
        try:
            c.execute("INSERT INTO users (username, password_hash, embedding_hash, salt) VALUES (?, ?, ?, ?)",
                      (username, password_hash, embedding_hash, salt))
        except sqlite3.IntegrityError:
            # username is the PRIMARY KEY: another request registered the same
            # name between the check above and this insert.
            return "Username already exists."
        conn.commit()
    finally:
        # Close on every path, including failures while hashing/embedding.
        conn.close()
    return "User created successfully."
def verify_user(username, password):
    """Check *password* for *username* against both stored credentials.

    Login succeeds only if the Argon2 hash verifies (password plus hex salt)
    AND the salted SHA-256 hash of the password's embedding matches the
    stored ``embedding_hash``.

    Returns:
        True on success. False if the user does not exist, if either check
        fails, or if verification raises (e.g. a password mismatch, a corrupt
        row, or a model error). Failures are reported as a failed login
        instead of crashing the caller.
    """
    import hmac  # constant-time comparison of the stored digest

    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute("SELECT * FROM users WHERE username = ?", (username,))
        user = c.fetchone()
    finally:
        conn.close()
    if not user:
        return False
    try:
        # Argon2's verify raises on a mismatch rather than returning False.
        ph.verify(user['password_hash'], password + user['salt'].hex())
        embedding = get_embedding(password)
        embedding_hash = hash_embedding(embedding, user['salt'])
        # compare_digest avoids leaking the match length through timing; a
        # missing/non-ASCII stored value raises TypeError and lands below.
        return hmac.compare_digest(embedding_hash, user['embedding_hash'])
    except Exception:
        # Deliberately best-effort, but no longer swallows KeyboardInterrupt
        # or SystemExit the way the bare ``except:`` did.
        return False
def get_user_files(username):
    """List the files owned by *username*.

    Returns a list of rows, each with ``filename`` and ``size`` columns.
    """
    connection = get_db_connection()
    rows = connection.execute(
        "SELECT filename, size FROM files WHERE username = ?", (username,)
    ).fetchall()
    connection.close()
    return rows
def upload_file(username, filename, content):
    """Store *content* as *filename* in *username*'s vault.

    The existence check and the insert are one SQL statement, so two
    concurrent uploads of the same name cannot both succeed. The ``files``
    table has no UNIQUE constraint that would catch such duplicates.
    ``size`` is recorded as ``len(content)``.

    Returns:
        "File <name> already exists." or "File <name> uploaded successfully."
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute(
            "INSERT INTO files (username, filename, content, size) "
            "SELECT ?, ?, ?, ? "
            "WHERE NOT EXISTS (SELECT 1 FROM files WHERE username = ? AND filename = ?)",
            (username, filename, content, len(content), username, filename),
        )
        if c.rowcount == 0:
            # The NOT EXISTS guard suppressed the insert.
            return f"File {filename} already exists."
        conn.commit()
    finally:
        conn.close()
    return f"File {filename} uploaded successfully."
def download_file(username, filename):
    """Fetch the stored content of *filename* for *username*.

    Returns the ``content`` column value, or None when no such file exists.
    """
    connection = get_db_connection()
    row = connection.execute(
        "SELECT content FROM files WHERE username = ? AND filename = ?",
        (username, filename),
    ).fetchone()
    connection.close()
    return row['content'] if row else None
def delete_file(username, filename):
    """Remove *filename* from *username*'s vault.

    Commits only when a row was actually deleted. Returns a status message
    saying whether the file was deleted or not found.
    """
    connection = get_db_connection()
    result = connection.execute(
        "DELETE FROM files WHERE username = ? AND filename = ?", (username, filename)
    )
    removed = result.rowcount != 0
    if removed:
        connection.commit()
    connection.close()
    if removed:
        return f"File {filename} deleted successfully."
    return f"File {filename} not found."
def empty_vault(username):
    """Delete every file owned by *username*.

    The confirmation message is returned even if the user had no files.
    """
    connection = get_db_connection()
    connection.execute("DELETE FROM files WHERE username = ?", (username,))
    connection.commit()
    connection.close()
    return "All files in your vault have been deleted."
def is_admin(username):
    """Return True if *username* is the configured admin account.

    ADMIN_USERNAME is read from the environment and may be unset (None) or
    empty. In that case nobody is admin. Without this guard, a missing
    username (None) would equal an unset ADMIN_USERNAME and be granted admin
    rights.
    """
    if not ADMIN_USERNAME:
        return False
    return username == ADMIN_USERNAME
def get_all_accounts():
    """Return rows with ``username`` and ``created_at`` for every registered user."""
    connection = get_db_connection()
    accounts = connection.execute("SELECT username, created_at FROM users").fetchall()
    connection.close()
    return accounts
def delete_account(username):
    """Delete *username* and all of their files; refuse for the admin account.

    Both deletes are committed together with a single commit. Returns a
    status message.
    """
    if username != ADMIN_USERNAME:
        connection = get_db_connection()
        cursor = connection.cursor()
        for statement in ("DELETE FROM users WHERE username = ?",
                          "DELETE FROM files WHERE username = ?"):
            cursor.execute(statement, (username,))
        connection.commit()
        connection.close()
        return f"Account {username} and all associated files have been deleted."
    return "Cannot delete admin account."
def encrypt_file(filename, content):
    """Encrypt *content* with the module-level Fernet ``cipher`` and return the token.

    ``filename`` is not used; it is kept for the caller-facing signature.
    """
    token = cipher.encrypt(content)
    return token
def decrypt_file(filename, encrypted_content):
    """Decrypt a Fernet token with the module-level ``cipher`` and return the plaintext.

    ``filename`` is not used; it is kept for the caller-facing signature.
    A token produced under a different key fails to decrypt (Fernet raises).
    """
    plaintext = cipher.decrypt(encrypted_content)
    return plaintext
# Rate limiting
RATE_LIMIT = 5  # maximum number of requests per minute
rate_limit_dict = {}  # username -> (window_start_timestamp, requests_in_window)


def is_rate_limited(username):
    """Apply a fixed 60-second window allowing RATE_LIMIT requests per user.

    Returns True, without counting the call, when *username* has already used
    RATE_LIMIT requests in its current window. Otherwise the call is recorded
    and False is returned. A new window starts at the first request made 60
    or more seconds after the current window began.
    """
    now = time.time()
    entry = rate_limit_dict.get(username)
    if entry is not None:
        window_start, used = entry
        if now - window_start < 60:
            if used >= RATE_LIMIT:
                return True
            rate_limit_dict[username] = (window_start, used + 1)
            return False
    # No entry yet, or the previous window expired: open a fresh window.
    rate_limit_dict[username] = (now, 1)
    return False
# Account lockout
MAX_LOGIN_ATTEMPTS = 5
LOCKOUT_TIME = 300  # 5 minutes
lockout_dict = {}  # username -> [failed_attempts, timestamp_of_last_failure]


def is_account_locked(username):
    """Return True while *username* is locked out.

    An account is locked once it has MAX_LOGIN_ATTEMPTS recorded failures
    and the most recent failure is less than LOCKOUT_TIME seconds old. An
    expired lockout is cleared here, so the user starts from zero failures.
    """
    entry = lockout_dict.get(username)
    if entry is None:
        return False
    attempts, last_failure = entry
    if attempts < MAX_LOGIN_ATTEMPTS:
        return False
    if time.time() - last_failure < LOCKOUT_TIME:
        return True
    lockout_dict.pop(username, None)
    return False


def record_login_attempt(username, success):
    """Record the outcome of a login attempt for *username*.

    A success clears all recorded failures. A failure increments the count
    and refreshes the failure timestamp. If the previous failure is
    LOCKOUT_TIME seconds old or more, the count restarts at 1, so sparse
    failures spread over a long period no longer accumulate into a lockout.
    """
    if success:
        lockout_dict.pop(username, None)
        return
    now = time.time()
    entry = lockout_dict.get(username)
    if entry is None or now - entry[1] >= LOCKOUT_TIME:
        lockout_dict[username] = [1, now]
    else:
        entry[0] += 1
        entry[1] = now