# YouTube Talks — Streamlit app: transcribe a YouTube video's audio with
# Whisper, index the transcript in FAISS, and answer questions about the
# video through a LangChain retrieval (RAG) pipeline.
# other imports | |
import streamlit as st, os | |
import spaces | |
import tempfile | |
from pytube import YouTube | |
import whisper | |
from dotenv import load_dotenv | |
# langchain imports | |
from langchain_community.vectorstores import FAISS | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_openai import OpenAI | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
# Initialize the session-state slots on first run so that later reads
# (chain / transcription / thumbnail_url) never raise a KeyError across
# Streamlit reruns.
for _slot in ("chain", "transcription", "thumbnail_url"):
    if _slot not in st.session_state:
        st.session_state[_slot] = None
# function for getting transcribing YouTube video | |
def get_yt_trans(yt_obj):
    """Transcribe the audio track of a YouTube video.

    Downloads the first audio-only stream of *yt_obj* (a pytube.YouTube
    object) into a temporary directory, runs Whisper's 'base' model on it,
    and stores the stripped transcript text in
    ``st.session_state.transcription``. Returns nothing.
    """
    # Audio-only stream is enough — the video frames are never used.
    audio_stream = yt_obj.streams.filter(only_audio=True).first()
    # Whisper 'base' model performs the speech-to-text conversion.
    speech_model = whisper.load_model('base')
    # Download into a throwaway directory that is cleaned up automatically.
    with tempfile.TemporaryDirectory() as workdir:
        audio_path = audio_stream.download(output_path=workdir)
        # fp16=False forces fp32 inference (safe on CPU-only machines).
        result = speech_model.transcribe(audio_path, fp16=False)
        st.session_state.transcription = result["text"].strip()
# function for chunking transcription text | |
def get_vid_text_chunks(transcription_text):
    """Split the transcript string into overlapping chunks for embedding.

    300-character chunks with a 20-character overlap keep neighboring
    context together across chunk boundaries.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
    return splitter.split_text(transcription_text)
# function for creating vector db | |
def create_vecdb(chunks):
    """Build an in-memory FAISS vector store from the transcript chunks.

    Each chunk is embedded with OpenAI embeddings; the resulting store is
    returned for retrieval-augmented querying.
    """
    embeddings = OpenAIEmbeddings()
    # vecdb.save_local('vec_store') ## if needed to store and retrieve locally
    return FAISS.from_texts(chunks, embeddings)
# Load API keys from a local .env file.
load_dotenv()
# BUG FIX: the original assigned to os.environ['OPEN_API_KEY'] — a typo the
# OpenAI SDK never reads (it expects OPENAI_API_KEY). Also guard against a
# missing key: os.environ[...] = None raises TypeError.
_openai_key = os.getenv('OPENAI_API_KEY')
if _openai_key:
    os.environ['OPENAI_API_KEY'] = _openai_key
# temperature=0 → deterministic answers for the Q&A chain.
llm = OpenAI(temperature=0)
# Set page configs (must be the first Streamlit call that renders the page).
st.set_page_config(
    page_title="YouTube Talks",
    page_icon='📽️'
)
st.header('Query YouTube Videos!')
# Placeholder so the video thumbnail can be (re)drawn in a fixed spot.
thumbnail_placeholder = st.empty()
# Sidebar: URL entry + submit button trigger the transcription pipeline.
st.sidebar.header("URL details:")
video_url = st.sidebar.text_input("Enter video URL")
submit_button = st.sidebar.button("Submit")
# Main area: free-text question about the transcribed video.
user_query = st.text_input("Query the Video!")
video_query_button = st.button("Ask video!")
# Placeholders for transient progress/status messages (sidebar and main).
progress_updates = st.sidebar.empty()
main_ph = st.empty()
# --- Submit: download, transcribe and index the video ----------------------
if submit_button:
    if not video_url.strip():
        # Robustness fix: the original passed an empty URL straight to pytube.
        st.error("Please enter a YouTube URL before submitting.")
    else:
        main_ph.text("Transcribing Video, please wait...")
        progress_updates.text("Transcribing Video...")
        try:
            # pytube raises (RegexMatchError, VideoUnavailable, ...) on bad
            # or unreachable URLs — surface that as an app error instead of
            # crashing the rerun.
            yt_obj = YouTube(video_url)
            # Show the thumbnail while (and after) we work.
            st.session_state.thumbnail_url = yt_obj.thumbnail_url
            thumbnail_placeholder.image(st.session_state.thumbnail_url)
            # Transcribe audio → st.session_state.transcription.
            get_yt_trans(yt_obj)
        except Exception as exc:
            main_ph.text("")
            progress_updates.text("")
            st.error(f"Could not fetch or transcribe the video: {exc}")
        else:
            st.sidebar.subheader("Transcription:")
            st.sidebar.write(st.session_state.transcription)
            progress_updates.text("Making Text Chunks...")
            # Chunk the transcript for embedding.
            chunks = get_vid_text_chunks(st.session_state.transcription)
            progress_updates.text("Creating Vector DB...")
            vector_db = create_vecdb(chunks)
            # Prompt grounds the answer in retrieved context only.
            p_template = '''Answer the question based on the context below. If you can't
answer the question, reply "I don't know".
Context: {context}
Question: {question}'''
            prompt = ChatPromptTemplate.from_template(p_template)
            parser = StrOutputParser()
            # RAG chain: retriever fills {context}; the raw question passes
            # through unchanged; stored in session state to survive reruns.
            st.session_state.chain = (
                {"context": vector_db.as_retriever(), "question": RunnablePassthrough()}
                | prompt
                | llm
                | parser
            )
            progress_updates.text("Video Transcribed Successfully!!!")
            main_ph.text("Video Transcribed Successfully!")
# --- Ask: run the RAG chain over the transcribed video ---------------------
if video_query_button:
    # BUG FIX: the original re-rendered the transcription and thumbnail
    # BEFORE checking whether a video had been transcribed, so with no
    # transcription yet it called thumbnail_placeholder.image(None) and
    # crashed. Check the chain first; only render artifacts that exist.
    if st.session_state.chain is None:
        st.error("Please transcribe a video first by submitting a URL.")
    else:
        # Re-display transcription and thumbnail (widgets reset each rerun).
        st.sidebar.subheader("Transcription:")
        st.sidebar.write(st.session_state.transcription)
        if st.session_state.thumbnail_url:
            thumbnail_placeholder.image(st.session_state.thumbnail_url)
        main_ph.text("Fetching Results...")
        st.subheader("Result:")
        main_ph.text("Displaying Results...")
        # Invoke the RAG chain with the user's question and print the answer.
        st.write(st.session_state.chain.invoke(user_query))