# talk_to_YouTube / app.py
# (Hugging Face Space header residue preserved as comments so the file parses)
# the-confused-coder's picture
# Update app.py
# 58bb07d verified
# other imports
import streamlit as st, os
import spaces
import tempfile
from pytube import YouTube
import whisper
from dotenv import load_dotenv
# langchain imports
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_openai import OpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# Lazily create the session-state slots this app relies on, so reruns of the
# script (Streamlit reruns top-to-bottom on every interaction) keep prior values.
for _state_key in ("chain", "transcription", "thumbnail_url"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# function for getting transcribing YouTube video
def get_yt_trans(yt_obj):
    """Transcribe the audio track of a YouTube video with Whisper.

    Downloads the first audio-only stream of *yt_obj* into a temporary
    directory, runs Whisper's 'base' model on it, and stores the stripped
    transcript text in ``st.session_state.transcription``.

    Args:
        yt_obj: a ``pytube.YouTube`` object for the video to transcribe.
    """
    # NOTE(review): .first() returns None when no audio-only stream exists,
    # which would raise AttributeError on .download() below — confirm with
    # pytube's StreamQuery docs if hardening is wanted.
    video_audio = yt_obj.streams.filter(only_audio=True).first()
    # 'base' model: small enough for CPU-only inference.
    audio_to_text_model = whisper.load_model('base')
    # Temp directory cleans the downloaded audio file up automatically.
    with tempfile.TemporaryDirectory() as tmpdir:
        audio_file = video_audio.download(output_path=tmpdir)
        # fp16=False avoids the half-precision warning on CPU hosts.
        st.session_state.transcription = audio_to_text_model.transcribe(
            audio_file, fp16=False
        )["text"].strip()
# function for chunking transcription text
def get_vid_text_chunks(transcription_text):
    """Split a transcript string into overlapping chunks for embedding."""
    # 300-char chunks with 20-char overlap keep neighbouring context together.
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
    return splitter.split_text(transcription_text)
# function for creating vector db
def create_vecdb(chunks):
    """Build an in-memory FAISS vector store from a list of text chunks."""
    embeddings = OpenAIEmbeddings()
    store = FAISS.from_texts(chunks, embeddings)
    # To persist instead of keeping in memory: store.save_local('vec_store')
    return store
# Load API keys from a local .env file, if present.
load_dotenv()
# BUG FIX: the value was previously written back under the misspelled name
# 'OPEN_API_KEY' (langchain-openai reads 'OPENAI_API_KEY'), and assigning a
# missing key crashed with TypeError because os.environ values must be str.
_openai_key = os.getenv('OPENAI_API_KEY')
if _openai_key is not None:
    # Re-export under the correct name so downstream libraries can read it.
    os.environ['OPENAI_API_KEY'] = _openai_key
# Deterministic completions for Q&A.
llm = OpenAI(temperature=0)
# --- Page layout: configure the page and create all widgets/placeholders ---
st.set_page_config(
    page_title="YouTube Talks",
    page_icon='📽️'
)
st.header('Query YouTube Videos!')
# Empty placeholder so the thumbnail can be rendered/refreshed in place later.
thumbnail_placeholder = st.empty()
st.sidebar.header("URL details:")
# Sidebar inputs drive the transcription flow; main-page inputs drive queries.
video_url = st.sidebar.text_input("Enter video URL")
submit_button = st.sidebar.button("Submit")
user_query = st.text_input("Query the Video!")
video_query_button = st.button("Ask video!")
# Placeholders for progress/status messages (sidebar and main area).
progress_updates = st.sidebar.empty()
main_ph = st.empty()
# --- Submit handler: transcribe the video and build the RAG chain ---
if submit_button:
    main_ph.text("Transcribing Video, please wait...")
    progress_updates.text("Transcribing Video...")
    # Fetch video metadata/streams for the supplied URL.
    yt_obj = YouTube(video_url)
    # Cache the thumbnail URL in session state so later reruns can re-show it.
    st.session_state.thumbnail_url = yt_obj.thumbnail_url
    thumbnail_placeholder.image(st.session_state.thumbnail_url)
    # Transcribe audio; result is stored in st.session_state.transcription.
    get_yt_trans(yt_obj)
    st.sidebar.subheader("Transcription:")
    st.sidebar.write(st.session_state.transcription)
    progress_updates.text("Making Text Chunks...")
    # Split the transcript into embedding-sized chunks.
    chunks = get_vid_text_chunks(st.session_state.transcription)
    progress_updates.text("Creating Vector DB...")
    # Embed chunks into an in-memory FAISS store.
    vector_db = create_vecdb(chunks)
    # Prompt that grounds answers strictly in the retrieved context.
    p_template = '''Answer the question based on the context below. If you can't
answer the question, reply "I don't know".
Context: {context}
Question: {question}'''
    prompt = ChatPromptTemplate.from_template(p_template)
    # Parse the LLM output down to a plain string.
    parser = StrOutputParser()
    # LCEL chain: retriever fills {context}; the raw query passes through to {question}.
    st.session_state.chain = {"context": vector_db.as_retriever(), "question": RunnablePassthrough()} | prompt | llm | parser
    progress_updates.text("Video Transcribed Successfully!!!")
    main_ph.text("Video Transcribed Successfully!")
# --- Query handler: answer the user's question with the stored chain ---
if video_query_button:
    # Re-render transcription and thumbnail: Streamlit reruns the whole
    # script on every click, so these must be redrawn from session state.
    st.sidebar.subheader("Transcription:")
    st.sidebar.write(st.session_state.transcription)
    thumbnail_placeholder.image(st.session_state.thumbnail_url)
    # Guard: a missing chain means no video has been transcribed yet.
    if st.session_state.chain is None:
        st.error("Please transcribe a video first by submitting a URL.")
    else:
        main_ph.text("Fetching Results...")
        st.subheader("Result:")
        main_ph.text("Displaying Results...")
        # Run RAG: retrieve relevant chunks, prompt the LLM, parse to string.
        st.write(st.session_state.chain.invoke(user_query))