import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
# community / huggingface partner packages replace the deprecated langchain.* re-exports
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.llm import LLMChain
# Load PDF
loader = PyPDFLoader("Medical_Book.pdf")
documents = loader.load()
# Split text
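# ~1000-character chunks with a small overlap keep retrieved passages coherent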
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
all_splits = text_splitter.split_documents(documents)
# Embeddings
model_name = "sentence-transformers/all-mpnet-base-v2"
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={"device": "cpu"})
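# Index the chunk embeddings in an in-memory FAISS store for similarity search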
vectorstores = FAISS.from_documents(all_splits, embeddings)
# Load LLM with quantization
llm_model = "ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel"
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(llm_model, quantization_config=bnb_config, trust_remote_code=True, use_cache=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(llm_model, trust_remote_code=True)
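# Mistral's tokenizer ships without a pad token, so reuse the EOS token for padding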
tokenizer.pad_token = tokenizer.eos_token
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.7, top_p=0.9, do_sample=True)
llm = HuggingFacePipeline(pipeline=pipe)
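# Wrap the transformers pipeline so LangChain chains can call it as an LLM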
# Prompt templates
condense_prompt = PromptTemplate(
    input_variables=["question", "chat_history"],
    template="""
You are a helpful medical assistant. Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.
Chat History:
{chat_history}
Follow-up Question:
{question}
Standalone Question:"""
)
qa_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""
Use the following context to answer the question.
Context:
{context}
Question:
{question}
Answer:
"""
)
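# question_generator condenses the chat history plus follow-up into a standalone query;
# combine_docs_chain "stuffs" the retrieved chunks into the QA prompt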
question_generator = LLMChain(llm=llm, prompt=condense_prompt)
combine_docs_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=qa_prompt)
def chatbot_response(user_input, max_new_tokens, temperature, context_length):
    # Apply the slider values per request via pipeline_kwargs; changing model.config
    # alone would not override the generation kwargs the pipeline was created with.
    llm.pipeline_kwargs = {"max_new_tokens": int(max_new_tokens), "temperature": float(temperature), "top_p": 0.9, "do_sample": True}
    pipe.model.config.context_length = int(context_length)  # kept for reference; not a standard generation setting
    chain = ConversationalRetrievalChain(retriever=vectorstores.as_retriever(), combine_docs_chain=combine_docs_chain, question_generator=question_generator, return_source_documents=True)
    chat_history = []  # stateless UI: every request starts with an empty history
    result = chain({"question": user_input, "chat_history": chat_history})
    return f"<div style='max-height: 400px; overflow-y: auto;'>{result['answer']}</div>"
interface = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(lines=2, placeholder="Type your question here...", label="Your Question", interactive=True),
        gr.Slider(label="Max New Tokens", minimum=1, maximum=2000, value=150, step=1, interactive=True),
        gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.01, interactive=True),
        gr.Slider(label="Context Length", minimum=100, maximum=4000, value=2000, step=1, interactive=True)
    ],
    outputs=gr.HTML(label="Chatbot Response"),
    title="🩺 MEDICAL Chatbot",
    description="""
    <div style='text-align: center;'>
        <img src='https://cdn.dribbble.com/users/29678/screenshots/2407580/media/34ee4b818fd4ddb3a616c91ccf4d9cfc.png' alt='Medical Bot' width='100'>
        <p>Check the responses from the fine-tuned Mistral-7B medical model!</p>
    </div>
    """
)
interface.launch()