# main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import google.generativeai as genai
from typing import List, Dict
import os
from dotenv import load_dotenv
import io
from datetime import datetime
import uuid
import json
import re

# File-format libraries
import PyPDF2
import docx
import openpyxl
import csv
import pptx

# The DB model is aliased so it does not collide with the Pydantic ChatMessage defined below.
from db import get_db, Chat, ChatMessage as ChatMessageDB, User, SessionLocal
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer
import requests
from jose import jwt, JWTError
from jose.exceptions import ExpiredSignatureError

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

DOMAIN = "http://localhost:8000"

# Replace these with your own values from the Google Developer Console.
# (Credentials like these belong in .env, not in source control.)
GOOGLE_CLIENT_ID = "862058885628-e6mjev28p8e112qrp9gnn4q8mlif3bbf.apps.googleusercontent.com"
GOOGLE_CLIENT_SECRET = "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"
GOOGLE_REDIRECT_URI = "http://localhost:5173/callback"
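# A minimal sketch of pulling the same values from the environment instead of
# hard-coding them (assumes matching keys exist in the .env file that
# load_dotenv() reads below, so this would have to run after that call):
#
#   GOOGLE_CLIENT_ID = os.environ["GOOGLE_CLIENT_ID"]
#   GOOGLE_CLIENT_SECRET = os.environ["GOOGLE_CLIENT_SECRET"]
#   GOOGLE_REDIRECT_URI = os.getenv("GOOGLE_REDIRECT_URI", "http://localhost:5173/callback")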
def parse_json_from_gemini(json_str: str):
    """Parse a JSON object out of a Gemini reply, unwrapping a ```json fence if present."""
    try:
        # Remove potential leading/trailing whitespace.
        json_str = json_str.strip()
        # Extract JSON content from triple backticks with a "json" language specifier.
        json_match = re.search(r"```json\s*(.*?)\s*```", json_str, re.DOTALL)
        if json_match:
            json_str = json_match.group(1)
        return json.loads(json_str)
    except (json.JSONDecodeError, AttributeError):
        return None
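# Illustrative calls, assuming a typical fenced Gemini reply (not part of the
# request path; shown only to document the expected shapes):
#
#   parse_json_from_gemini('```json\n{"insight": "x", "pareto_analysis": {}}\n```')
#   -> {'insight': 'x', 'pareto_analysis': {}}
#
#   parse_json_from_gemini('not json')  -> None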
load_dotenv()

app = FastAPI(title="EduScope AI")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/login/google")  # Route decorator reconstructed; the original path was not preserved.
async def login_google():
    return {
        "url": f"https://accounts.google.com/o/oauth2/auth?response_type=code&client_id={GOOGLE_CLIENT_ID}&redirect_uri={GOOGLE_REDIRECT_URI}&scope=openid%20profile%20email&access_type=offline"
    }
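# Sketch of the client side of this flow (hypothetical, using `requests` for brevity):
#
#   url = requests.get(f"{DOMAIN}/login/google").json()["url"]
#   # The browser is redirected to `url`; Google then calls back to
#   # GOOGLE_REDIRECT_URI with ?code=..., which the frontend forwards to the
#   # auth endpoint below to complete the exchange.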
@app.get("/auth/google")  # Route decorator reconstructed; the original path was not preserved.
async def auth_google(code: str, db: Session = Depends(get_db)):
    token_url = "https://accounts.google.com/o/oauth2/token"
    data = {
        "code": code,
        "client_id": GOOGLE_CLIENT_ID,
        "client_secret": GOOGLE_CLIENT_SECRET,
        "redirect_uri": GOOGLE_REDIRECT_URI,
        "grant_type": "authorization_code",
    }
    response = requests.post(token_url, data=data)
    access_token = response.json().get("access_token")
    if not access_token:
        raise HTTPException(status_code=401, detail="Failed to exchange code for an access token")
    user_info = requests.get(
        "https://www.googleapis.com/oauth2/v1/userinfo",
        headers={"Authorization": f"Bearer {access_token}"},
    ).json()
    user = db.query(User).filter(User.id == user_info["id"]).first()
    if not user:
        user = User(id=user_info["id"], email=user_info["email"], name=user_info["name"])
        db.add(user)
        db.commit()
    return {"token": jwt.encode(user_info, GOOGLE_CLIENT_SECRET, algorithm="HS256")}
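# The issued token is simply the Google userinfo payload signed with HS256.
# An illustrative decode (claim values are made up):
#
#   jwt.decode(token, GOOGLE_CLIENT_SECRET, algorithms=["HS256"])
#   -> {"id": "1234567890", "email": "student@example.com", "name": "Ada"}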
async def decode_token(authorization: str = Header(...)):
    if not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=400,
            detail="Authorization header must start with 'Bearer '",
        )
    token = authorization[len("Bearer "):]  # Extract the token part
    try:
        # Decode and verify the JWT. python-jose raises its own exception types
        # (jose.exceptions), not PyJWT's jwt.ExpiredSignatureError / jwt.InvalidTokenError.
        token_data = jwt.decode(token, GOOGLE_CLIENT_SECRET, algorithms=["HS256"])
        return token_data  # Return the decoded claims
    except ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token has expired")
    except JWTError:
        raise HTTPException(status_code=401, detail="Invalid token")
@app.get("/token")  # Route decorator reconstructed; the original path was not preserved.
async def get_token(user_data: dict = Depends(decode_token)):
    return user_data


@app.post("/chats")  # Route decorator reconstructed; the original path was not preserved.
async def create_chat(title: str, user_data: dict = Depends(decode_token), db: Session = Depends(get_db)):
    user_id = user_data["id"]
    chat = Chat(chat_id=str(uuid.uuid4()), user_id=user_id, title=title)
    db.add(chat)
    db.commit()
    return {"chat_id": chat.chat_id, "title": title, "timestamp": chat.timestamp}


@app.get("/chats")  # Route decorator reconstructed; the original path was not preserved.
async def get_chats(user_data: dict = Depends(decode_token), db: Session = Depends(get_db)):
    user_id = user_data["id"]
    chats = db.query(Chat).filter(Chat.user_id == user_id).all()
    return [{"chat_id": chat.chat_id, "title": chat.title, "timestamp": chat.timestamp} for chat in chats]
# The Gemini API key should likewise come from the environment in a real deployment.
genai.configure(api_key="AIzaSyDZsN3hnnNQOBLSAznFh7xWbWKNohvqff0")
model = genai.GenerativeModel('gemini-1.5-flash')

# In-memory stores; their contents are lost on restart.
documents = {}
chat_history = []


class Document(BaseModel):
    id: str
    name: str
    content: str
    timestamp: str


class Query(BaseModel):
    text: str
    selected_docs: List[str]


class ChatMessage(BaseModel):
    id: str
    type: str  # 'user' or 'assistant'
    content: str
    timestamp: str
    referenced_docs: List[str] = []


class Analysis(BaseModel):
    insight: str
    pareto_analysis: dict
def extract_text_from_file(file: UploadFile):
    """
    Extract text from various file types.
    Supports: PDF, DOCX, XLSX, CSV, TXT, PPTX
    """
    file_extension = os.path.splitext(file.filename)[1].lower()
    content = file.file.read()
    try:
        if file_extension == '.pdf':
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(content))
            # extract_text() can return None for image-only pages.
            text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
        elif file_extension == '.docx':
            doc = docx.Document(io.BytesIO(content))
            text = "\n".join([para.text for para in doc.paragraphs])
        elif file_extension == '.xlsx':
            wb = openpyxl.load_workbook(io.BytesIO(content), read_only=True)
            text = ""
            for sheet in wb:
                for row in sheet.iter_rows(values_only=True):
                    text += " ".join(str(cell) for cell in row if cell is not None) + "\n"
        elif file_extension == '.csv':
            csv_reader = csv.reader(io.StringIO(content.decode('utf-8')))
            text = "\n".join([" ".join(row) for row in csv_reader])
        elif file_extension == '.txt':
            text = content.decode('utf-8')
        elif file_extension in ['.ppt', '.pptx']:
            # Note: python-pptx only reads .pptx; a legacy .ppt file will raise here
            # and surface as the 400 below.
            ppt = pptx.Presentation(io.BytesIO(content))
            text = ""
            for slide in ppt.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        text += shape.text + "\n"
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")
        return text
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing file: {str(e)}")
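# Quick illustrative use outside the endpoint (hypothetical local file; keyword
# arguments keep this compatible with both older and newer UploadFile signatures):
#
#   with open("notes.pdf", "rb") as f:
#       text = extract_text_from_file(UploadFile(file=f, filename="notes.pdf"))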
@app.post("/documents")  # Route decorator reconstructed; the original path was not preserved.
async def upload_document(file: UploadFile = File(...)):
    try:
        text = extract_text_from_file(file)
        doc_id = str(uuid.uuid4())
        document = Document(
            id=doc_id,
            name=file.filename,
            content=text,
            timestamp=datetime.now().isoformat(),
        )
        documents[doc_id] = document
        return document.dict()
    except HTTPException as e:
        raise e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")


@app.get("/documents")  # Route decorator reconstructed; the original path was not preserved.
async def get_documents():
    return list(documents.values())
@app.post("/analyze")  # Route decorator reconstructed; the original path was not preserved.
async def analyze_text(query: Query):
    # Guard against stale or unknown document ids before indexing into `documents`.
    missing = [doc_id for doc_id in query.selected_docs if doc_id not in documents]
    if missing:
        raise HTTPException(status_code=404, detail=f"Unknown document ids: {missing}")

    # Combine content from the selected documents.
    combined_context = "\n\n".join([
        f"Document '{documents[doc_id].name}':\n{documents[doc_id].content}"
        for doc_id in query.selected_docs
    ])
    prompt = f"""
    Analyze the following text in the context of this query: {query.text}

    Context from multiple documents:
    {combined_context}

    Provide:
    1. Detailed insights and analysis, comparing information across documents when relevant
    2. Apply the Pareto Principle (80/20 rule) to identify the most important aspects

    Format the response as JSON with 'insight' and 'pareto_analysis' keys.
    Example format:
    {{
        "insight": "Key findings and analysis from the documents...",
        "pareto_analysis": {{
            "vital_few": "The 20% of factors that drive 80% of the impact...",
            "trivial_many": "The remaining 80% of factors that contribute 20% of the impact..."
        }}
    }}

    Also give a complete HTML document with illustrative analysis such as pie charts, bar charts, graphs, etc.
    """
    response = model.generate_content(prompt)
    response_text = response.text

    # Record the user's message.
    message = ChatMessage(
        id=str(uuid.uuid4()),
        type="user",
        content=query.text,
        timestamp=datetime.now().isoformat(),
        referenced_docs=query.selected_docs,
    )
    chat_history.append(message)

    # Record the assistant's structured analysis (None if the reply wasn't valid JSON).
    analysis = parse_json_from_gemini(response_text)
    assistant_message = ChatMessage(
        id=str(uuid.uuid4()),
        type="assistant",
        content=json.dumps(analysis, indent=4),
        timestamp=datetime.now().isoformat(),
        referenced_docs=query.selected_docs,
    )
    chat_history.append(assistant_message)

    # If the model also returned an HTML report, store it as a separate message.
    if '```html' in response_text:
        html = response_text.split('```html')[1]
        html = html.split('```')[0]
        html = html.strip()
        assistant_message = ChatMessage(
            id=str(uuid.uuid4()),
            type="assistant",
            content=html,
            timestamp=datetime.now().isoformat(),
            referenced_docs=query.selected_docs,
        )
        chat_history.append(assistant_message)
    return analysis
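# Illustrative request against this endpoint (the doc id is made up):
#
#   requests.post(
#       f"{DOMAIN}/analyze",
#       json={"text": "Summarize the key risks", "selected_docs": ["<doc-uuid>"]},
#   )
#
# The JSON body maps onto the Query model above; the response is the parsed
# 'insight' / 'pareto_analysis' object, or null if Gemini's reply was not valid JSON.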
@app.get("/chat-history")  # Route decorator reconstructed; the original path was not preserved.
async def get_chat_history():
    return chat_history


@app.delete("/clear")  # Route decorator reconstructed; the original path was not preserved.
async def clear_all():
    chat_history.clear()
    documents.clear()
    return {"message": "All data cleared successfully"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)