Spaces:
Build error
Build error
| from langchain import PromptTemplate, LLMChain | |
| from langchain.llms import CTransformers, HuggingFacePipeline, GooglePalm | |
| import os | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.vectorstores import Chroma | |
| from langchain.chains import RetrievalQA | |
| from langchain.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings | |
| from io import BytesIO | |
| from langchain.document_loaders import PyPDFLoader | |
| import gradio as gr | |
| import chromadb | |
| from constants import CHROMA_SETTINGS | |
| from io import BytesIO | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM, AutoModel | |
| import gc | |
| from langchain.schema.runnable import RunnableLambda, RunnablePassthrough | |
| from langchain.chat_models import ChatGooglePalm | |
| import google.generativeai as genai | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
# Release cached memory before loading models. These calls only matter when a
# local GPU model is used (a local HuggingFace model was previously loaded here);
# torch.cuda.empty_cache() is a no-op when CUDA was never initialized.
gc.collect()
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Runtime configuration, read from environment variables.
persist_directory = os.environ.get('PERSIST_DIRECTORY')
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
google_api_key = os.environ.get('GOOGLE_API_KEY')
if persist_directory is None:
    # chromadb.PersistentClient(path=None) does not fail here but later with an
    # obscure "missing persist_directory" error; fail fast with a clear message.
    raise RuntimeError(
        "PERSIST_DIRECTORY environment variable must be set to the Chroma database directory"
    )

# Hosted PaLM chat model. Passing the key explicitly makes the dependency on
# GOOGLE_API_KEY visible; when it is None the client falls back to the env var.
llm = ChatGooglePalm(google_api_key=google_api_key)

print("Loading embeddings model...")
embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001",
    google_api_key=google_api_key,
)

# Chroma client backed by the on-disk database in persist_directory.
chroma_client = chromadb.PersistentClient(settings=CHROMA_SETTINGS, path=persist_directory)
db = Chroma(
    persist_directory=persist_directory,
    embedding_function=embeddings,
    client_settings=CHROMA_SETTINGS,
    client=chroma_client,
)
# Prompt for the "stuff" QA chain: the retrieved chunks fill {context} and the
# user's query fills {question}.
prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
prompt = PromptTemplate(
    input_variables=['context', 'question'],
    template=prompt_template,
)

# Retriever over the Chroma store returning `target_source_chunks` chunks per query.
search_settings = {"k": target_source_chunks}
retriever = db.as_retriever(search_kwargs=search_settings)

# Extra keyword arguments for RetrievalQA.from_chain_type: use the custom prompt.
chain_type_kwargs = dict(prompt=prompt)
# Question textbox used as the Interface input: label hidden, at most two
# lines tall, rendered without a surrounding container.
input_gradio = gr.Text(
    placeholder="Enter your question here",
    label="Prompt",
    show_label=False,
    container=False,
    max_lines=2,
)
| def get_response(input_gradio ): | |
| query=input_gradio | |
| qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents= False, chain_type_kwargs=chain_type_kwargs, verbose=True) | |
| response= qa(query) | |
| return response['result'] | |
# Gradio UI: one question textbox in, the chain's plain-text answer out;
# the flagging button is disabled.
iface = gr.Interface(
    fn=get_response,
    inputs=input_gradio,
    outputs="text",
    title="Tsetlin Machine Chatbot",
    description="A chatbot that uses the LLM to answer anything regarding TM",
    allow_flagging='never',
)

if __name__ == "__main__":
    # Start the web server only when run as a script (e.g. `python app.py`),
    # so importing this module does not launch the app and block.
    iface.launch()