# Hugging Face Spaces status at time of capture: Runtime error
import streamlit as st | |
from PyPDF2 import PdfReader | |
from transformers import pipeline, AutoTokenizer, AutoModel | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
import faiss | |
import numpy as np | |
# Model loading.
# Streamlit re-executes this whole script on every widget interaction, so the
# loaders are wrapped in st.cache_resource: each model is loaded once per
# server process and reused on later reruns instead of being reloaded.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


@st.cache_resource
def load_text_generator():
    """Return the google/flan-t5-base text2text-generation pipeline."""
    return pipeline("text2text-generation", model="google/flan-t5-base")


@st.cache_resource
def load_embedding_model():
    """Load the sentence-embedding tokenizer and model.

    Returns:
        tuple: ``(tokenizer, model)`` for all-MiniLM-L6-v2, loaded with
        ``AutoTokenizer`` / ``AutoModel``.
    """
    tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_NAME)
    model = AutoModel.from_pretrained(_EMBEDDING_MODEL_NAME)
    return tokenizer, model


text_generator = load_text_generator()
embedding_tokenizer, embedding_model = load_embedding_model()
# Function to extract text from PDF
def extract_pdf_content(pdf_file):
    """Extract the text of every page of a PDF.

    Args:
        pdf_file: A path or binary file-like object accepted by
            ``PyPDF2.PdfReader`` (here, a Streamlit upload).

    Returns:
        str: The page texts joined with newlines. Pages for which
        ``extract_text()`` yields nothing contribute an empty string.
    """
    reader = PdfReader(pdf_file)
    # ``or ""`` guards against extract_text() giving None/"" (e.g. image-only
    # pages) so the join never fails. The newline between pages keeps the last
    # word of one page from fusing with the first word of the next, which
    # would otherwise corrupt the whitespace-based chunking downstream.
    return "\n".join(page.extract_text() or "" for page in reader.pages)
# Function to split content into chunks
def chunk_text(text, chunk_size=500):
    """Split *text* into chunks of at most *chunk_size* whitespace-separated words.

    Words are re-joined with single spaces, so the original whitespace
    (newlines, tabs, runs of spaces) is not preserved. Each chunk is later
    tokenized with ``truncation=True``, so very large chunks may be cut short
    when embedded.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk; must be >= 1.

    Returns:
        list[str]: The chunks in original order; ``[]`` for empty or
        whitespace-only text.

    Raises:
        ValueError: If *chunk_size* is less than 1.
    """
    if chunk_size < 1:
        # range() raises an opaque error for 0 and silently yields no chunks
        # for negative steps; fail with a clear message instead.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
# Function to compute embeddings.
# Documents and queries both go through compute_embeddings so that their
# vectors are produced the same way and live in the same space.
def mean_pool(token_embeddings, attention_mask):
    """Average token vectors over non-padding tokens, then L2-normalize.

    This is the pooling the all-MiniLM-L6-v2 model card prescribes (mean
    pooling followed by normalization), rather than BERT's ``pooler_output``.

    Args:
        token_embeddings: Array of shape ``(batch, seq_len, dim)``.
        attention_mask: Array of shape ``(batch, seq_len)``; non-zero for
            real tokens, 0 for padding.

    Returns:
        np.ndarray: float32 array of shape ``(batch, dim)`` with unit-norm
        rows (rows that pool to all zeros stay zero).
    """
    tokens = np.asarray(token_embeddings, dtype=np.float32)
    mask = np.asarray(attention_mask, dtype=np.float32)[:, :, np.newaxis]
    summed = (tokens * mask).sum(axis=1)
    # Clamp avoids division by zero for a row that is entirely padding.
    counts = np.maximum(mask.sum(axis=1), 1e-9)
    pooled = summed / counts
    norms = np.maximum(np.linalg.norm(pooled, axis=1, keepdims=True), 1e-12)
    return (pooled / norms).astype(np.float32, copy=False)


def compute_embeddings(text_chunks, batch_size=32):
    """Embed each text with the shared MiniLM model.

    Args:
        text_chunks: Sequence of strings to embed.
        batch_size: Number of texts tokenized and run through the model at once.

    Returns:
        np.ndarray: float32 array of shape ``(len(text_chunks), dim)``. An
        empty input gives shape ``(0, dim)``, so it is still 2-D.
    """
    import torch  # only needed when the model actually runs

    texts = list(text_chunks)
    batches = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = embedding_tokenizer(batch, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():  # inference only: don't build an autograd graph
            outputs = embedding_model(**inputs)
        batches.append(
            mean_pool(outputs.last_hidden_state.numpy(), inputs["attention_mask"].numpy())
        )
    if not batches:
        return np.empty((0, embedding_model.config.hidden_size), dtype=np.float32)
    return np.concatenate(batches, axis=0)


# Function to build FAISS index
def build_faiss_index(embeddings):
    """Build an exact (brute-force) FAISS L2 index over *embeddings*.

    Args:
        embeddings: Array-like of shape ``(n, dim)``.

    Returns:
        faiss.IndexFlatL2 containing the n vectors.

    Raises:
        ValueError: If *embeddings* is not 2-D.
    """
    # FAISS requires C-contiguous float32 input.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(f"embeddings must be a 2-D (n, dim) array, got shape {vectors.shape}")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index


# Function to search in FAISS index
def search_faiss_index(index, query_embedding, text_chunks, top_k=3):
    """Return up to *top_k* ``(chunk, distance)`` pairs nearest to the query.

    Distances are as reported by the index (squared L2 for IndexFlatL2), so
    smaller means more relevant.

    Args:
        index: FAISS index whose vector ids are positions in *text_chunks*.
        query_embedding: Query vector, shape ``(dim,)`` or ``(1, dim)``.
        text_chunks: Chunks in the same order they were added to *index*.
        top_k: Maximum number of results.

    Returns:
        list[tuple[str, float]]: Results ordered nearest first.
    """
    # Asking for more neighbours than the index holds makes FAISS pad the
    # result with label -1, which would wrongly select text_chunks[-1].
    k = min(int(top_k), int(index.ntotal))
    if k <= 0:
        return []
    query = np.ascontiguousarray(np.atleast_2d(query_embedding), dtype=np.float32)
    distances, indices = index.search(query, k)
    return [
        (text_chunks[idx], float(dist))
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0  # defensive: skip any padded "no result" labels
    ]


# Function to generate structured content
def generate_professional_content(topic):
    """Ask the text-generation pipeline for a bullet-point explanation of *topic*.

    Returns:
        str: The generated text (generation capped with ``max_length=300``).
    """
    prompt = f"Explain '{topic}' in bullet points, highlighting key concepts, examples, and applications."
    response = text_generator(prompt, max_length=300, num_return_sequences=1)
    return response[0]['generated_text']


# Function to compute query embedding
def compute_query_embedding(query):
    """Embed a single query string.

    Returns:
        np.ndarray: float32 array of shape ``(1, dim)``.
    """
    # Same path as the documents, so query and chunk vectors are comparable.
    return compute_embeddings([query])
# Streamlit app
st.title("Generative AI for Electrical Engineering Education with FAISS")
st.sidebar.header("AI-Based Tutor with Vector Search")

# File upload section
uploaded_file = st.sidebar.file_uploader("Upload Study Material (PDF)", type=["pdf"])
topic = st.sidebar.text_input("Enter a topic (e.g., Newton's Third Law)")

# NOTE(review): Streamlit re-runs this script on every widget interaction, so
# the uploaded PDF is re-parsed and re-embedded on each rerun (including the
# button click below). Consider caching this work keyed on the file contents.
index = None  # stays None unless a searchable index is built
chunks = []
if uploaded_file:
    # Extract and process file content
    content = extract_pdf_content(uploaded_file)
    st.sidebar.success(f"{uploaded_file.name} uploaded successfully!")

    # Chunk, embed and index the content
    chunks = chunk_text(content)
    if chunks:
        embeddings = compute_embeddings(chunks)
        index = build_faiss_index(embeddings)
        st.write("**File Processed and Indexed for Search**")
        st.write(f"Total chunks created: {len(chunks)}")
    else:
        # No extractable text (e.g. a scanned PDF): an index over zero chunks
        # cannot be built, so report it instead of crashing.
        st.warning("No extractable text found in the uploaded PDF.")

# Generate study material
if st.button("Generate Study Material"):
    topic = topic.strip()  # a whitespace-only topic counts as empty
    if topic:
        st.header(f"Study Material: {topic}")

        # Search FAISS index (only when one was built above)
        if index is not None:
            query_embedding = compute_query_embedding(topic)
            results = search_faiss_index(index, query_embedding, chunks, top_k=3)
            st.write("**Relevant Content from Uploaded File:**")
            for result, distance in results:
                # The index is IndexFlatL2: a distance, where lower = closer.
                st.write(f"- {result} (Distance: {distance:.2f}; lower = more relevant)")
        elif uploaded_file:
            st.warning("The uploaded PDF has no searchable text. Generating AI-based content instead.")
        else:
            st.warning("No file uploaded. Generating AI-based content instead.")

        # Generate AI content
        ai_content = generate_professional_content(topic)
        st.write("**AI-Generated Content:**")
        st.write(ai_content)
    else:
        st.warning("Please enter a topic!")