Spaces:
Running
Running
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
import google.generativeai as genai | |
from typing import List | |
import os | |
from dotenv import load_dotenv | |
import io | |
from datetime import datetime, timedelta | |
import uuid | |
import json | |
import re | |
# File Format Libraries | |
import PyPDF2 | |
import docx | |
import openpyxl | |
import csv | |
import io | |
import pptx | |
from db import get_db, Chat, ChatMessage, User, Document, SessionLocal | |
from pyqs import get_q_paper | |
from fastapi.security import OAuth2PasswordBearer | |
import requests | |
from jose import jwt | |
import random | |
# NOTE(review): oauth2_scheme is not referenced by any code visible in this file;
# decode_token parses the Authorization header itself.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Load .env before the os.getenv() calls below.
load_dotenv()
GOOGLE_CLIENT_ID = os.getenv('GOOGLE_CLIENT_ID')
# Also used as the HS256 signing key for the app's own JWTs (auth_google / decode_token).
GOOGLE_CLIENT_SECRET = os.getenv('GOOGLE_CLIENT_SECRET')
GOOGLE_REDIRECT_URI = os.getenv('GOOGLE_REDIRECT_URI')
# Comma-separated pool of Gemini API keys; analyze_text picks one at random.
# Whitespace around entries and empty entries (e.g. "k1, k2,") are dropped so a
# stray space never becomes part of a key.
api_keys = [key.strip() for key in os.getenv('GEMINI_API_KEYS', '').split(',') if key.strip()]
if not api_keys:
    # Fail at startup with an actionable message rather than an AttributeError
    # (variable unset) or a bad/empty key being sent on the first request.
    raise RuntimeError("GEMINI_API_KEYS must be set to one or more comma-separated API keys")
# Matches a fenced code block whose body is a JSON object or array. The "json"
# language tag is optional and case-insensitive; fences whose body does not
# start with "{" or "[" (such as an ```html block) are skipped.
_JSON_FENCE_RE = re.compile(r"```(?:json)?\s*([\[{].*?[\]}])\s*```", re.DOTALL | re.IGNORECASE)


def parse_json_from_gemini(json_str: str):
    """Parse the JSON payload out of a Gemini text reply.

    The reply may wrap the JSON in a Markdown code fence (tagged ``json`` in
    any case, or untagged) and may contain other fenced blocks such as HTML.
    The first fence whose body looks like a JSON object/array is used. If
    there is no such fence, the whole stripped reply is parsed.

    Args:
        json_str: Raw model reply text.

    Returns:
        The decoded JSON value, or None if it could not be decoded (also when
        ``json_str`` is None).
    """
    try:
        json_str = json_str.strip()
        fenced = _JSON_FENCE_RE.search(json_str)
        if fenced:
            json_str = fenced.group(1)
        return json.loads(json_str)
    except (json.JSONDecodeError, AttributeError):
        return None
load_dotenv()

app = FastAPI(title="EduScope AI")

# NOTE(review): none of the async handlers in this copy of the file carry an
# @app.<method>(...) decorator; confirm how they are registered as routes.

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): the FastAPI CORS docs advise against a "*" origin together with
# allow_credentials=True; consider listing the frontend origin(s) explicitly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
async def login_google():
    """Return the Google OAuth 2.0 authorization URL for the client to open.

    Requests an authorization code (``response_type=code``) for the
    ``openid profile email`` scopes with offline access, redirecting back to
    GOOGLE_REDIRECT_URI.

    Returns:
        ``{"url": <authorization URL>}``.
    """
    from urllib.parse import quote, urlencode

    params = {
        "response_type": "code",
        "client_id": GOOGLE_CLIENT_ID,
        "redirect_uri": GOOGLE_REDIRECT_URI,
        "scope": "openid profile email",
        "access_type": "offline",
    }
    # urlencode percent-escapes every value, so a redirect URI carrying its own
    # "?" or "&" cannot corrupt the query string. quote_via=quote encodes the
    # spaces in scope as %20.
    query = urlencode(params, quote_via=quote)
    return {"url": f"https://accounts.google.com/o/oauth2/auth?{query}"}
async def auth_google(code: str, db: SessionLocal = Depends(get_db)):
    """Complete Google sign-in: exchange ``code`` for the user's profile and issue an app JWT.

    A ``User`` row is created from the Google profile the first time that
    Google account signs in.

    Args:
        code: Authorization code Google sent back to GOOGLE_REDIRECT_URI.
        db: Database session from ``get_db``.

    Returns:
        ``{"token": <JWT>}`` — the Google userinfo payload signed with HS256
        using GOOGLE_CLIENT_SECRET, the scheme ``decode_token`` verifies.

    Raises:
        HTTPException 400: Google returned no access token for ``code``.
        HTTPException 502: a request to Google failed or returned non-JSON,
            or the profile has no ``id``.
    """
    token_url = "https://accounts.google.com/o/oauth2/token"
    data = {
        "code": code,
        "client_id": GOOGLE_CLIENT_ID,
        "client_secret": GOOGLE_CLIENT_SECRET,
        "redirect_uri": GOOGLE_REDIRECT_URI,
        "grant_type": "authorization_code",
    }
    # Explicit timeouts: requests never times out by default, and these
    # synchronous calls hold up the event loop while they run.
    try:
        token_response = requests.post(token_url, data=data, timeout=10)
        access_token = token_response.json().get("access_token")
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"Google token exchange failed: {str(e)}") from e
    if not access_token:
        # Without a token the userinfo call below would send "Bearer None".
        raise HTTPException(status_code=400, detail="Google did not return an access token for this code")

    try:
        user_info = requests.get(
            "https://www.googleapis.com/oauth2/v1/userinfo",
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=10,
        ).json()
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"Fetching Google user info failed: {str(e)}") from e
    if "id" not in user_info:
        raise HTTPException(status_code=502, detail="Google user info did not include an id")

    user = db.query(User).filter(User.id == user_info["id"]).first()
    if not user:
        user = User(id=user_info["id"], email=user_info["email"], name=user_info["name"])
        db.add(user)
        db.commit()
    # NOTE(review): the payload has no "exp" claim, so these tokens never expire
    # and decode_token's expired-token branch cannot trigger for them.
    return {"token": jwt.encode(user_info, GOOGLE_CLIENT_SECRET, algorithm="HS256")}
async def decode_token(authorization: str = Header(...)):
    """Verify the ``Bearer`` JWT in the Authorization header and return its claims.

    Tokens are expected to be those issued by ``auth_google``: HS256-signed
    with GOOGLE_CLIENT_SECRET.

    Returns:
        The decoded claims dict; handlers read the user id from ``"id"``.

    Raises:
        HTTPException 400: the header does not start with ``"Bearer "``.
        HTTPException 401: the token has expired or is otherwise invalid
            (bad signature, malformed, failed claim checks).
    """
    # python-jose defines its errors in jose.exceptions. It has no
    # InvalidTokenError (that name is PyJWT's), so the previous
    # `except jwt.InvalidTokenError` raised AttributeError -> HTTP 500.
    from jose.exceptions import ExpiredSignatureError, JWTError

    if not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=400,
            detail="Authorization header must start with 'Bearer '"
        )
    token = authorization[len("Bearer "):]  # Extract token part
    try:
        # Decode and verify the JWT token
        return jwt.decode(token, GOOGLE_CLIENT_SECRET, algorithms=["HS256"])
    except ExpiredSignatureError:
        # ExpiredSignatureError subclasses JWTError, so it must be handled first.
        raise HTTPException(status_code=401, detail="Token has expired") from None
    except JWTError:
        raise HTTPException(status_code=401, detail="Invalid token") from None
async def get_token(user_data: dict = Depends(decode_token)):
    """Return the caller's decoded JWT claims exactly as ``decode_token`` produced them."""
    claims = user_data
    return claims
async def create_chat(title: str, user_data: dict = Depends(decode_token), db: SessionLocal = Depends(get_db)):
    """Create and persist a new chat owned by the authenticated user.

    Args:
        title: Display title for the chat.
        user_data: Decoded JWT claims; ``"id"`` is the owner's user id.
        db: Database session from ``get_db``.

    Returns:
        dict with the new ``chat_id`` (a random UUID4 string), the ``title`` as
        passed in, and ``timestamp`` read back from the committed row
        (presumably a column default in db.py — confirm).
    """
    owner_id = user_data["id"]
    new_chat = Chat(chat_id=str(uuid.uuid4()), user_id=owner_id, title=title)
    db.add(new_chat)
    db.commit()
    summary = {
        "chat_id": new_chat.chat_id,
        "title": title,
        "timestamp": new_chat.timestamp,
    }
    return summary
async def get_chats(user_data: dict = Depends(decode_token), db: SessionLocal = Depends(get_db)):
    """List every chat owned by the authenticated user.

    Returns:
        A list of ``{"chat_id", "title", "timestamp"}`` dicts, in the order the
        query yields them (no explicit ORDER BY).
    """
    owner_id = user_data["id"]

    def summarize(row):
        return {"chat_id": row.chat_id, "title": row.title, "timestamp": row.timestamp}

    owned = db.query(Chat).filter(Chat.user_id == owner_id).all()
    return list(map(summarize, owned))
class DocumentSchema(BaseModel):
    """Pydantic shape of a document summary: id, file name and timestamp string.

    NOTE(review): not referenced anywhere in this file; get_documents builds
    plain dicts with the same three keys instead.
    """
    id: str
    name: str
    # ISO-8601 string in get_documents' output.
    timestamp: str
class Query(BaseModel):
    """Request body for analyze_text."""
    # The user's question; interpolated into the Gemini prompt.
    text: str
    # Ids of documents to analyze. analyze_text also filters by chat_id, so ids
    # from other chats are ignored.
    selected_docs: List[str]
class ChatMessageSchema(BaseModel):
    """Pydantic shape of one chat-history entry, mirroring get_chat_history's dicts.

    NOTE(review): not referenced anywhere in this file.
    """
    id: str
    type: str # 'user' or 'assistant'
    content: str
    # ISO-8601 string in get_chat_history's output.
    timestamp: str
    # Mutable default; pydantic is expected to copy it per instance — confirm
    # for the pinned pydantic version.
    referenced_docs: List[str] = []
class Analysis(BaseModel):
    """Shape of the JSON analysis that analyze_text's prompt asks Gemini to return.

    NOTE(review): not used for validation; analyze_text returns the parsed dict
    as-is.
    """
    insight: str
    # The prompt's example uses the keys "vital_few" and "trivial_many".
    pareto_analysis: dict
def _pdf_to_text(content: bytes) -> str:
    """Return the text of each PDF page, joined with newlines."""
    reader = PyPDF2.PdfReader(io.BytesIO(content))
    # `or ""` in case a page's extract_text() yields None (older PyPDF2
    # releases); otherwise str.join would raise.
    return "\n".join(page.extract_text() or "" for page in reader.pages)


def _docx_to_text(content: bytes) -> str:
    """Return the paragraph text of a .docx document, one paragraph per line."""
    document = docx.Document(io.BytesIO(content))
    return "\n".join(para.text for para in document.paragraphs)


def _xlsx_to_text(content: bytes) -> str:
    """Return the non-empty cells of every sheet, space-separated, one row per line."""
    workbook = openpyxl.load_workbook(io.BytesIO(content), read_only=True)
    try:
        return "".join(
            " ".join(str(cell) for cell in row if cell is not None) + "\n"
            for sheet in workbook
            for row in sheet.iter_rows(values_only=True)
        )
    finally:
        # openpyxl requires read-only workbooks to be closed explicitly.
        workbook.close()


def _csv_to_text(content: bytes) -> str:
    """Return CSV rows as space-separated fields, one row per line."""
    # utf-8-sig also strips a leading byte-order mark (common in spreadsheet exports).
    reader = csv.reader(io.StringIO(content.decode("utf-8-sig")))
    return "\n".join(" ".join(row) for row in reader)


def _txt_to_text(content: bytes) -> str:
    """Decode a UTF-8 text file, dropping a leading byte-order mark if present."""
    return content.decode("utf-8-sig")


def _pptx_to_text(content: bytes) -> str:
    """Return the text of every text-bearing shape on every slide, one shape per line."""
    presentation = pptx.Presentation(io.BytesIO(content))
    return "".join(
        shape.text + "\n"
        for slide in presentation.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    )


# Lower-cased file extension -> extractor. Legacy binary .ppt is intentionally
# absent: python-pptx reads only .pptx, so .ppt gets a clear "Unsupported file
# type" error instead of an opaque parser failure.
_TEXT_EXTRACTORS = {
    ".pdf": _pdf_to_text,
    ".docx": _docx_to_text,
    ".xlsx": _xlsx_to_text,
    ".csv": _csv_to_text,
    ".txt": _txt_to_text,
    ".pptx": _pptx_to_text,
}


def extract_text_from_file(file: "UploadFile") -> str:
    """
    Extract text from various file types
    Supports: PDF, DOCX, XLSX, CSV, TXT, PPTX

    The (case-insensitive) extension of ``file.filename`` selects the parser.
    The whole upload is read into memory first.

    Args:
        file: The uploaded file; only its ``filename`` and ``file`` attributes are used.

    Returns:
        The extracted plain text.

    Raises:
        HTTPException 400: unsupported or missing extension, or the parser
            failed (including invalid UTF-8 in .txt/.csv files).
    """
    # A missing filename yields an empty extension and the "Unsupported" error.
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    content = file.file.read()
    try:
        extractor = _TEXT_EXTRACTORS.get(file_extension)
        if extractor is None:
            raise ValueError(f"Unsupported file type: {file_extension}")
        return extractor(content)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing file: {str(e)}") from e
async def search_by_subject_code(subject_code: str, user_data: dict = Depends(decode_token)):
    """Proxy a subject-code search to the Thapar question-paper site.

    ``user_data`` is not read; the ``decode_token`` dependency only enforces
    that the caller is authenticated.

    Returns:
        The site's JSON response, decoded.

    Raises:
        HTTPException 502: the request failed or the response was not JSON.
    """
    try:
        response = requests.get(
            "https://cl.thapar.edu/search1.php",
            # params= percent-encodes the user-supplied term, so characters such
            # as "&" or "#" cannot alter the upstream query.
            params={"term": subject_code},
            # NOTE(review): verify=False (kept from the original) disables TLS
            # certificate checks for this host.
            verify=False,
            timeout=10,
        )
        return response.json()
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"Subject code search failed: {str(e)}") from e
def _download_pdf_text(download_link: str) -> str:
    """Download one question-paper PDF and return its extracted text.

    Raises:
        HTTPException 500: the download failed (network error or non-200
            status) or the PDF could not be parsed.
    """
    try:
        # NOTE(review): verify=False (kept from the original) disables TLS
        # certificate checks. The timeout bounds a stalled download.
        response = requests.get(download_link, verify=False, timeout=30)
    except requests.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Failed to download the paper from {download_link}") from e
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=f"Failed to download the paper from {download_link}")
    try:
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(response.content))
        # `or ""` in case extract_text() yields None for a page.
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process PDF: {str(e)}") from e


async def import_q_papers(chat_id: str, subject_code: str, user_data: dict = Depends(decode_token), db: SessionLocal = Depends(get_db)):
    """Download the question papers for ``subject_code`` and store them as documents of a chat.

    Each paper returned by ``get_q_paper`` (a dict with DownloadLink,
    CourseName, Year, Semester and ExamType) is downloaded and its PDF text
    saved as a ``Document`` named ``<Course>_<Year>_<Semester>_<ExamType>.pdf``.
    The import is all-or-nothing: documents are committed only after every
    paper has been processed.

    Raises:
        HTTPException 404: chat not found for this user, or no papers for the code.
        HTTPException 500: a paper could not be downloaded or parsed.
    """
    user_id = user_data["id"]
    chat = db.query(Chat).filter(Chat.chat_id == chat_id, Chat.user_id == user_id).first()
    if not chat:
        raise HTTPException(status_code=404, detail="Chat not found")
    q_papers = get_q_paper(subject_code)
    if not q_papers:
        raise HTTPException(status_code=404, detail="No question papers found for the given subject code")
    try:
        for paper in q_papers:
            text = _download_pdf_text(paper["DownloadLink"])
            title = f"{paper['CourseName']}_{paper['Year']}_{paper['Semester']}_{paper['ExamType']}.pdf"
            db.add(Document(
                id=str(uuid.uuid4()),
                chat_id=chat_id,
                name=title,
                content=text,
                timestamp=datetime.now(),
            ))
        db.commit()
    except Exception:
        # Discard documents staged for earlier papers so a failure part-way
        # through leaves the chat unchanged.
        db.rollback()
        raise
    return {"message": "Question papers imported successfully"}
async def upload_document(chat_id: str, file: UploadFile = File(...), user_data: dict = Depends(decode_token), db: SessionLocal = Depends(get_db)):
    """Extract an uploaded file's text and store it as a document of one of the caller's chats.

    Returns:
        dict with the new document's ``id``, ``name`` (the upload's filename)
        and ``timestamp`` as an ISO-8601 string.

    Raises:
        HTTPException 404: the chat does not exist or belongs to another user.
        HTTPException from extract_text_from_file: passed through unchanged.
        HTTPException 500: any other failure ("Unexpected error: ...").
    """
    owner_id = user_data["id"]
    owned_chat = (
        db.query(Chat)
        .filter(Chat.chat_id == chat_id, Chat.user_id == owner_id)
        .first()
    )
    if not owned_chat:
        raise HTTPException(status_code=404, detail="Chat not found")

    try:
        extracted = extract_text_from_file(file)
        record = Document(
            id=str(uuid.uuid4()),
            chat_id=chat_id,
            name=file.filename,
            content=extracted,
            timestamp=datetime.now(),
        )
        db.add(record)
        db.commit()
        db.refresh(record)
        return {
            "id": record.id,
            "name": record.name,
            "timestamp": record.timestamp.isoformat(),
        }
    except HTTPException:
        # Already carries the right status and detail; let it through untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
async def get_documents(chat_id: str, user_data: dict = Depends(decode_token), db: SessionLocal = Depends(get_db)):
    """List the documents attached to one of the caller's chats.

    Returns:
        A list of ``{"id", "name", "timestamp"}`` dicts, timestamp as ISO-8601.

    Raises:
        HTTPException 404: the chat does not exist or belongs to another user.
    """
    owner_id = user_data["id"]
    owned_chat = (
        db.query(Chat)
        .filter(Chat.chat_id == chat_id, Chat.user_id == owner_id)
        .first()
    )
    if not owned_chat:
        raise HTTPException(status_code=404, detail="Chat not found")

    summaries = []
    for row in db.query(Document).filter(Document.chat_id == chat_id).all():
        summaries.append({
            "id": row.id,
            "name": row.name,
            "timestamp": row.timestamp.isoformat(),
        })
    return summaries
# Prompt sent to Gemini by analyze_text. Filled with str.format, so the literal
# braces of the JSON example are doubled. Whitespace here is sent verbatim.
_ANALYSIS_PROMPT_TEMPLATE = """
Analyze the following text in the context of this query: {query}
Context from multiple documents:
{context}
Provide:
1. Detailed insights and analysis, comparing information across documents when relevant
2. Apply the Pareto Principle (80/20 rule) to identify the most important aspects
Format the response as JSON with 'insight' and 'pareto_analysis' keys.
Example format:
{{
"insight": "Key findings and analysis from the documents based on query...",
"pareto_analysis": {{
"vital_few": "The 20% of factors that drive 80% of the impact...",
"trivial_many": "The remaining 80% of factors that contribute 20% of the impact..."
}}
}}
also give a complete html document with a intreactive quiz (minimum 5 questions) using jquery and also a flashcards to help the user understand the content better.
"""


def _build_analysis_prompt(query_text: str, docs) -> str:
    """Fill the analysis prompt with the user's question and each document's name and content."""
    combined_context = "\n\n".join(
        f"Document '{doc.name}':\n{doc.content}" for doc in docs
    )
    return _ANALYSIS_PROMPT_TEMPLATE.format(query=query_text, context=combined_context)


def _extract_html_block(response_text: str):
    """Return the body of the first ```html fence (stripped), or None if there is none."""
    if '```html' not in response_text:
        return None
    return response_text.split('```html')[1].split('```')[0].strip()


async def analyze_text(chat_id: str, query: Query, user_data: dict = Depends(decode_token), db: SessionLocal = Depends(get_db)):
    """Run a Gemini analysis over selected documents of a chat and record the exchange.

    Args:
        chat_id: Chat whose documents are analyzed; must belong to the caller.
        query: The user's question plus the ids of the documents to include.
        user_data: Decoded JWT claims; ``"id"`` is the user id.
        db: Database session from ``get_db``.

    Returns:
        The value parsed from the model's JSON reply (the prompt asks for
        ``insight`` and ``pareto_analysis`` keys), or None if the reply could
        not be parsed.

    Raises:
        HTTPException 404: chat not found for this user.
        HTTPException 400: none of the selected documents exist in the chat.
        HTTPException 502: the model request failed or its text was unavailable.

    Side effects:
        Persists the user message, the analysis as an assistant message and,
        when the reply contains an ```html block, that HTML as a second
        assistant message.
    """
    user_id = user_data["id"]
    chat = db.query(Chat).filter(Chat.chat_id == chat_id, Chat.user_id == user_id).first()
    if not chat:
        raise HTTPException(status_code=404, detail="Chat not found")

    docs = db.query(Document).filter(Document.chat_id == chat_id, Document.id.in_(query.selected_docs)).all()
    if not docs:
        raise HTTPException(status_code=400, detail="No documents found for analysis")

    prompt = _build_analysis_prompt(query.text, docs)

    # A random key from the configured pool is used per request. Never log it:
    # it is a credential.
    # NOTE(review): genai.configure presumably sets process-wide client state;
    # that is only race-free because no `await` separates it from the call below.
    genai.configure(api_key=random.choice(api_keys))
    model = genai.GenerativeModel('gemini-1.5-flash')
    try:
        # Reading .text is expected to raise when the reply carries no text
        # (e.g. a blocked prompt), so it shares the guard with the request.
        response_text = model.generate_content(prompt).text
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Model request failed: {str(e)}") from e

    referenced = json.dumps(query.selected_docs)
    db.add(ChatMessage(
        id=str(uuid.uuid4()),
        chat_id=chat_id,
        type="user",
        content=query.text,
        timestamp=datetime.now(),
        referenced_docs=referenced,
    ))

    analysis = parse_json_from_gemini(response_text)
    # NOTE(review): this timestamp is 3 s in the past, which sorts the analysis
    # *before* the user message above in get_chat_history (ordered by
    # timestamp). Confirm that ordering is intended.
    db.add(ChatMessage(
        id=str(uuid.uuid4()),
        chat_id=chat_id,
        type="assistant",
        # json.dumps(None) stores "null" when the reply was not parseable JSON.
        content=json.dumps(analysis, indent=4),
        timestamp=datetime.now() - timedelta(seconds=3),
        referenced_docs=referenced,
    ))

    html = _extract_html_block(response_text)
    if html is not None:
        db.add(ChatMessage(
            id=str(uuid.uuid4()),
            chat_id=chat_id,
            type="assistant",
            content=html,
            timestamp=datetime.now(),
            referenced_docs=referenced,
        ))

    db.commit()
    return analysis
async def get_chat_history(chat_id: str, user_data: dict = Depends(decode_token), db: SessionLocal = Depends(get_db)):
    """Return all messages of one of the caller's chats, ordered by timestamp.

    Returns:
        A list of dicts with ``id``, ``type``, ``content``, ``timestamp``
        (ISO-8601) and ``referenced_docs`` (decoded from its stored JSON,
        or ``[]`` when empty/NULL).

    Raises:
        HTTPException 404: the chat does not exist or belongs to another user.
    """
    owner_id = user_data["id"]
    owned_chat = (
        db.query(Chat)
        .filter(Chat.chat_id == chat_id, Chat.user_id == owner_id)
        .first()
    )
    if not owned_chat:
        raise HTTPException(status_code=404, detail="Chat not found")

    def to_payload(message):
        return {
            "id": message.id,
            "type": message.type,
            "content": message.content,
            "timestamp": message.timestamp.isoformat(),
            # Stored as a JSON-encoded list of document ids.
            "referenced_docs": json.loads(message.referenced_docs) if message.referenced_docs else [],
        }

    history = (
        db.query(ChatMessage)
        .filter(ChatMessage.chat_id == chat_id)
        .order_by(ChatMessage.timestamp)
        .all()
    )
    return [to_payload(message) for message in history]
async def clear_chat(chat_id: str, user_data: dict = Depends(decode_token), db: SessionLocal = Depends(get_db)):
    """Delete every document and message of one of the caller's chats.

    The chat row itself is kept; only its contents are removed, in one commit.

    Raises:
        HTTPException 404: the chat does not exist or belongs to another user.
    """
    owner_id = user_data["id"]
    owned_chat = (
        db.query(Chat)
        .filter(Chat.chat_id == chat_id, Chat.user_id == owner_id)
        .first()
    )
    if not owned_chat:
        raise HTTPException(status_code=404, detail="Chat not found")

    # Documents first, then messages.
    for model, chat_column in ((Document, Document.chat_id), (ChatMessage, ChatMessage.chat_id)):
        db.query(model).filter(chat_column == chat_id).delete()
    db.commit()
    return {"message": "Chat cleared successfully"}
if __name__ == "__main__":
    import uvicorn

    # PORT lets the deployment choose the listening port without a code change;
    # 8000 remains the default.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))