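"""Streamlit RAG chatbot that answers questions about any URL using a local
Ollama LLM with LangChain, FAISS, and Ollama embeddings.

Assumed runtime setup (not checked by the code): an Ollama server listening on
http://localhost:11434 with the `nomic-embed-text` embedding model and at
least one chat model pulled, e.g.:

    ollama pull nomic-embed-text
    ollama pull llama3.2

Run the app with:

    streamlit run app.py  # adjust to this file's actual name
"""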
import logging
from typing import Any, Dict, List, Tuple

import ollama
import streamlit as st
from langchain import hub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.web_base import WebBaseLoader
from langchain_community.vectorstores.faiss import FAISS
from langchain_ollama import ChatOllama, OllamaEmbeddings



##### Logging

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)


# Standard RAG prompt from the LangChain Hub; it expects "context" and
# "question" input variables.
rag_prompt = hub.pull("rlm/rag-prompt")

# Embeddings served by the local Ollama instance (requires `ollama pull nomic-embed-text`).
embeddings = OllamaEmbeddings(model="nomic-embed-text", base_url="http://localhost:11434")

@st.cache_resource(show_spinner=True)
def extract_model_names(
    models_info: Dict[str, List[Dict[str, Any]]],
) -> Tuple[str, ...]:
    """
    Extract model names from the provided models information.

    Args:
        models_info (Dict[str, List[Dict[str, Any]]]): Dictionary containing information about available models.

    Returns:
        Tuple[str, ...]: A tuple of model names.
    """
    logger.info("Extracting model names from models_info")
    # "name" is the key in older ollama-python clients; newer ones use "model".
    model_names = tuple(m.get("name") or m.get("model") for m in models_info["models"])
    logger.info(f"Extracted model names: {model_names}")
    return model_names

def format_docs(docs) -> str:
    """Join retrieved document chunks into a single context string."""
    return "\n\n".join(doc.page_content for doc in docs)

## RAG with URL
def rag_with_url(target_url: str, model, question: str) -> str:
    """Answer `question` using content retrieved from `target_url`."""
    # Load the page, split it into overlapping chunks, and index them in FAISS.
    loader = WebBaseLoader(target_url)
    raw_documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
    split_docs = text_splitter.split_documents(raw_documents)
    vector_store = FAISS.from_documents(split_docs, embeddings)

    # Retrieve the most relevant chunks and fill them into the RAG prompt.
    retriever = vector_store.as_retriever()
    relevant_documents = retriever.invoke(question)
    final_prompt = rag_prompt.invoke({"context": format_docs(relevant_documents), "question": question})

    ai_response = model.invoke(final_prompt)
    return ai_response.content
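
# Hypothetical usage outside Streamlit (URL and model name here are
# illustrative only), assuming the llama3.2 chat model has been pulled
# into the local Ollama server:
#
#   model = ChatOllama(model="llama3.2", base_url="http://localhost:11434")
#   print(rag_with_url("https://example.com/post", model, "Summarize this page."))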

def main() -> None:

    st.set_page_config(layout="centered")

    st.title("🧠 RAG Chatbot for any URL with Llama 3.2 and LangChain")

    st.write("You can select which local LLM from Ollama should be used.")

    # List the models available on the local Ollama server and let the user pick one.
    models_info = ollama.list()
    available_models = extract_model_names(models_info)
    selected_model = st.selectbox("Local Ollama model:", available_models)
    model = ChatOllama(model=selected_model, base_url="http://localhost:11434")


    # URL text box for user input
    url_input = st.text_input("Enter a URL to be queried:", "")

    with st.form("llm-form"):
        text = st.text_area("Enter your question below:")
        submit = st.form_submit_button("Submit")

    if "chat_history" not in st.session_state:
        st.session_state["chat_history"] = []

    if submit and text:
        if not url_input:
            st.warning("Please enter a URL to query first.")
        else:
            with st.spinner("Generating response..."):
                response = rag_with_url(url_input, model, text)
                st.session_state["chat_history"].append({"user": text, "ollama": response})
                st.write(response)

    st.write("## Chat History")
    for chat in reversed(st.session_state['chat_history']):
        st.write(f"**🧑 User**: {chat['user']}")
        st.write(f"**🧠 Assistant**: {chat['ollama']}")
        st.write("---")

if __name__ == "__main__":
    main()