# LLM_summarizer / rag_pipeline.py
# Uploaded by arpita-23 ("Upload 2 files"), commit 18b6b25.
import os
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import pipeline
import json
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
def load_and_preprocess(file_path, chunk_size=1000, chunk_overlap=100,
                        separator="\n", encoding="utf-8"):
    """Load a text file and split it into overlapping chunks.

    Args:
        file_path: Path of the file to load. It is read as plain text via
            ``TextLoader``; a ``.csv`` file is NOT parsed into rows/columns.
        chunk_size: Maximum chunk length (in characters) for the splitter.
        chunk_overlap: Number of characters shared by consecutive chunks.
        separator: String the splitter breaks the text on before merging the
            pieces back up to ``chunk_size``. Defaults to a single newline:
            the splitter's own default (a blank line, ``"\\n\\n"``) never
            occurs in a line-oriented file such as a CSV, which would leave the
            entire file as one oversized chunk.
        encoding: Encoding used to read the file. Passed explicitly so the
            result does not depend on the platform default encoding
            (assumes the input files are UTF-8 — confirm for new sources).

    Returns:
        The list of chunked documents produced by ``split_documents``.
    """
    loader = TextLoader(file_path, encoding=encoding)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(
        separator=separator,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(documents)
def stable_chunk_ids(documents):
    """Return one deterministic string ID per document chunk.

    Each ID is the SHA-256 hex digest of the chunk's position in
    ``documents``, its ``metadata["source"]`` (empty if absent) and its text.
    Re-indexing the same inputs therefore reproduces the same IDs, while
    identical text at different positions still gets distinct IDs (Chroma
    rejects duplicate IDs within one write).

    Args:
        documents: Sequence of objects exposing ``page_content`` (str) and a
            ``metadata`` dict, e.g. LangChain ``Document`` chunks.

    Returns:
        A list of 64-character hex strings, one per input, in input order.
    """
    import hashlib

    def _digest(index, doc):
        source = doc.metadata.get("source", "")
        key = f"{index}\x00{source}\x00{doc.page_content}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    return [_digest(i, doc) for i, doc in enumerate(documents)]


def create_vector_store(documents, persist_directory, embeddings=None):
    """Embed ``documents`` and store them in a persistent Chroma collection.

    Args:
        documents: Chunks to index (e.g. the output of ``load_and_preprocess``).
        persist_directory: Directory in which Chroma persists the collection.
        embeddings: Optional embedding model. When omitted, a default
            ``HuggingFaceEmbeddings()`` instance is created, as before.

    Returns:
        The ``Chroma`` vector store holding the embedded chunks.
    """
    if embeddings is None:
        embeddings = HuggingFaceEmbeddings()
    # Without explicit IDs Chroma assigns random UUIDs, so every run of the
    # script appended another full copy of the corpus to the persisted
    # collection. Deterministic IDs make re-runs over the same inputs update
    # the existing entries instead of duplicating them.
    vector_store = Chroma.from_documents(
        documents,
        embeddings,
        ids=stable_chunk_ids(documents),
        persist_directory=persist_directory,
    )
    return vector_store
def initialize_llm(model="gpt2", max_new_tokens=100):
    """Create a LangChain LLM backed by a local Hugging Face text-generation pipeline.

    Args:
        model: Hugging Face model id passed to ``transformers.pipeline``.
        max_new_tokens: Number of tokens to generate after the prompt.

    Returns:
        A ``HuggingFacePipeline`` wrapping the text-generation pipeline.
    """
    # ``max_new_tokens`` must be set explicitly: GPT-2's bundled
    # text-generation settings cap the TOTAL length (prompt + output) at 50
    # tokens, and a RAG prompt with retrieved context is already longer than
    # that, so the model would generate essentially nothing.
    # NOTE(review): GPT-2's context window is 1024 tokens; the stuffed prompt
    # plus ``max_new_tokens`` must fit inside it — confirm with the chunk
    # size and k used by the retriever.
    generator = pipeline("text-generation", model=model, max_new_tokens=max_new_tokens)
    return HuggingFacePipeline(pipeline=generator)
def build_rag_pipeline(vector_store, llm, k=3, chain_type="stuff"):
    """Assemble a retrieval-augmented question-answering chain.

    Args:
        vector_store: Vector store to retrieve context chunks from.
        llm: Language model that writes the answer from the retrieved chunks.
        k: Number of most-similar chunks retrieved per query.
        chain_type: How retrieved chunks are combined into the prompt
            (``"stuff"`` concatenates them all into a single prompt).

    Returns:
        A ``RetrievalQA`` chain; call it with ``{"query": ...}``.
    """
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k})
    # RetrievalQA's constructor expects a ``combine_documents_chain``, not an
    # ``llm``, so ``RetrievalQA(llm=..., retriever=...)`` fails validation.
    # ``from_chain_type`` builds that inner chain from the LLM.
    return RetrievalQA.from_chain_type(llm=llm, chain_type=chain_type, retriever=retriever)
def main(
    gita_path=r'C:\LLM_summerizer\bhagavad_gita_verses.csv',
    yoga_path=r'C:\LLM_summerizer\yoga_raw.txt',
    query="What does the Bhagavad Gita say about selfless action?",
    persist_directory='vector_store',
):
    """Index the two source texts, answer one query, and print the result as JSON.

    All arguments default to the values the script originally hard-coded,
    so ``main()`` with no arguments behaves as before.

    Args:
        gita_path: Path of the Bhagavad Gita verses file.
        yoga_path: Path of the Yoga Sutras text file.
        query: Question passed to the RAG chain.
        persist_directory: Directory where the vector store is persisted.
    """
    documents = load_and_preprocess(gita_path) + load_and_preprocess(yoga_path)
    vector_store = create_vector_store(documents, persist_directory=persist_directory)
    llm = initialize_llm()
    rag_pipeline = build_rag_pipeline(vector_store, llm)
    # ``Chain.run`` is deprecated; ``invoke`` takes the input dict and returns
    # the output dict, whose "result" key holds the RetrievalQA answer.
    result = rag_pipeline.invoke({"query": query})["result"]
    output = {
        "query": query,
        "answer": result,
    }
    print(json.dumps(output, indent=2))
# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()