###########################################################################################
# Title:  Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date:   October 10th, 2024
# Last update: October 26th, 2024
##########################################################################################

import os
import torch
from transformers import AutoTokenizer, AutoModel # chromaDB
from datetime import datetime, date #add_doc, 
import chromadb #chromaDB
from chromadb import Documents, EmbeddingFunction, Embeddings #chromaDB
from chromadb.utils import embedding_functions #chromaDB
import ocrmypdf #convertPDF
from pypdf import PdfReader #convertPDF
import re #format_prompt
import gradio as gr # multimodal_response
from huggingface_hub import InferenceClient # multimodal_response
import json # multimodal_response (on-prem)
import requests # multimodal_response (on-prem)

#---------------------------------------------------
# Specify models for text generation and embeddings
#---------------------------------------------------

# Model used for text generation. Names containing "GGUF" are served by a local
# llama-cpp server (started below); all other names are queried through the
# Hugging Face InferenceClient in multimodal_response.
myModel="mistralai/Mixtral-8x7b-instruct-v0.1"
# Alternative models (uncomment exactly one assignment to switch):
#myModel="meta-llama/Llama-3.1-8B-Instruct"
#myModel="QuantFactory/gemma-2-9b-it-SimPO-GGUF"
#myModel="bartowski/gemma-2-9b-it-GGUF"
# Debugging aid: render a sample dialogue with the tokenizer's chat template to
# look up the prompt format of the selected model.
#mod=myModel
#tok=AutoTokenizer.from_pretrained(mod) #,token="hf_...")
#cha=[{"role":"system","content":"A"},{"role":"user","content":"B"},{"role":"assistant","content":"C"}]
#cha=[{"role":"user","content":"U1"},{"role":"assistant","content":"A1"},{"role":"user","content":"U2"},{"role":"assistant","content":"A2"}]
#res=tok.apply_chat_template(cha)
#print(tok.decode(res))

if("GGUF" in myModel): # start Llama-cpp-server for GGUF-models on premises:
  #modelPath="/home/af/gguf/models/bartowski/gemma-2-9b-it-GGUF/gemma-2-9b-it-Q4_K_M.gguf"
  modelPath="/home/af/gguf/models/QuantFactory/gemma-2-9b-it-SimPO-GGUF/gemma-2-9b-it-SimPO.Q4_K_M.gguf"
  if(not os.path.exists(modelPath)):
    localPath="./model.gguf"
    if(not os.path.exists(localPath)): # reuse a model downloaded by an earlier start
      #url="https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf?download=true"
      url="https://huggingface.co/QuantFactory/gemma-2-9b-it-SimPO-GGUF/resolve/main/gemma-2-9b-it-SimPO.Q4_K_M.gguf?download=true"
      partPath=localPath+".part"
      # Stream the file to disk in 1 MiB pieces instead of holding the whole
      # model in memory, and fail loudly on HTTP errors so that an error page
      # is never saved as model file.
      with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partPath, mode="wb") as file:
          for chunk in response.iter_content(chunk_size=1024*1024):
            file.write(chunk)
      # Rename only after the download finished, so an interrupted download
      # cannot be mistaken for a complete model on the next start.
      os.replace(partPath, localPath)
      print("Model downloaded")
    modelPath=localPath
  print(modelPath)
  import subprocess
  command = ["python3", "-m", "llama_cpp.server", "--model", modelPath, "--host", "0.0.0.0", "--port", "2600", "--n_threads", "4", "--n_gpu_layers","42"] #20
  # Popen returns immediately: the server process has been started here, but it
  # may still be loading the model when the message below is printed.
  subprocess.Popen(command)
  print("Server ready!")

#url="http://0.0.0.0:2600/v1/completions"  
#body={"prompt":"test","max_tokens":1000, "echo":"False","stream":"False"} #e.g. Mixtral-Instruct
#test=requests.post(url, json=body, stream=False)
  
# Embedding model used by JinaEmbeddingFunction (loaded with bfloat16 weights).
jina = AutoModel.from_pretrained(
  'jinaai/jina-embeddings-v2-base-de',
  trust_remote_code=True,
  torch_dtype=torch.bfloat16,
)
#jina.save_pretrained("jinaai_jina-embeddings-v2-base-de") # optional: keep a local copy of the model
# Run the embedding model on the first CUDA device when torch can see one, else on the CPU.
if torch.cuda.is_available():
  device='cuda:0'
else:
  device='cpu'
jina.to(device)
print(device)


#-----------------
# ChromaDB-client
#-----------------
 
class JinaEmbeddingFunction(EmbeddingFunction):
  """Embedding function for ChromaDB that delegates to the module-level `jina` model."""

  def __call__(self, input: Documents) -> Embeddings:
    """Encode `input` with `jina.encode` and return the result converted with `.tolist()`."""
    vectors = jina.encode(input) # a max_length (e.g. 2048) could be passed here
    return vectors.tolist()

# Database location: the on-premises folder is used when it exists,
# otherwise the path inside the app directory.
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
onPrem = os.path.exists(dbPath)
if not onPrem:
  dbPath = "/home/user/app/db/"

print(dbPath)
client = chromadb.PersistentClient(path=dbPath)
# Log connection state, ChromaDB version and the collections already stored.
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())

# One shared embedding function instance, available under both names.
embeddingModel = jina_ef = JinaEmbeddingFunction()
# List of (date, session) pairs; starts with the default session "0".
databases = [(date.today(), "0")]


#---------------------------------------------------------------------
# Function for formatting single message according to prompt template
#---------------------------------------------------------------------

def format_prompt0(message, history):
  """Wrap a single message in the "<s>[INST] ... [/INST]" template.

  `history` is accepted for signature compatibility but not used.
  """
  return f"<s>[INST] {message} [/INST]"


#-------------------------------------------------------------------------
# Function for formatting multiturn-dialogue according to prompt template
#-------------------------------------------------------------------------

def format_prompt(message, history=None, system=None, RAGAddon=None, system2=None, zeichenlimit=None,historylimit=4, removeHTML=False,
  startOfString="<s>", template0=" [INST] {system} [/INST] </s>",template1=" [INST] {message} [/INST]",template2=" {response}</s>"): # mistralai/Mixtral-8x7B-Instruct-v0.1
  #startOfString="<bos>",template0="<start_of_turn>user\n{system}<end_of_turn>\n<start_of_turn>model\n<end_of_turn>\n",template1="<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n",template2="{response}<end_of_turn>\n"): # google/gemma-2-2b-it
  #startOfString="<|begin_of_text|>", template0="<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n{system}\n<|eot_id|>", template1="<|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", template2="{response}<|eot_id|>"): # meta-llama/Llama-3.1-8B-Instruct
  """Build a multi-turn prompt string from a message, its history and a system text.

  Args:
    message: Current user message; skipped when None.
    history: Iterable of (user_message, bot_response) pairs or None.
    system: System text rendered with `template0`; skipped when None.
    RAGAddon: Text appended to `system` (used as system text when `system` is None).
    system2: Text appended verbatim at the very end of the prompt.
    zeichenlimit: Maximum number of characters kept per message; None means no limit.
    historylimit: Number of most recent history turns to include; None includes
      all turns, values <= 0 include none.
    removeHTML: If true, replace every "<...>" tag in bot responses by a line break.
    startOfString: Prefix of the whole prompt.
    template0, template1, template2: Format strings for the system text, a user
      message and a bot response.

  Returns:
    The assembled prompt string.
  """
  # A RAG addon without system text becomes the system text (None + str would raise).
  if RAGAddon is not None:
    system = RAGAddon if system is None else system + RAGAddon
  prompt = ""
  if system is not None:
    prompt += template0.format(system=system)
  if history is not None:
    if historylimit is None:
      turns = history
    elif historylimit > 0:
      turns = history[-historylimit:]
    else:
      turns = [] # history[-0:] would be the complete history, not an empty one
    for user_message, bot_response in turns:
      if user_message is None: user_message = ""
      if bot_response is None: bot_response = ""
      # Remove the appended RAG sources. Besides a plain "<details>" block this also
      # covers the "<br><details open>" form that multimodal_response emits.
      bot_response = re.sub(r"\n\n(?:<br>)?<details[^>]*>.*?</details>", "", bot_response, flags=re.DOTALL)
      # Remove HTML components in general (may cause bugs with markdown rendering).
      if removeHTML: bot_response = re.sub("<(.*?)>", "\n", bot_response)
      # Slicing with zeichenlimit=None keeps the complete text.
      prompt += template1.format(message=user_message[:zeichenlimit])
      prompt += template2.format(response=bot_response[:zeichenlimit])
  if message is not None:
    prompt += template1.format(message=message[:zeichenlimit])
  if system2 is not None:
    prompt += system2
  return startOfString+prompt


#--------------------------------------------
# Function for converting pdf-files to text
#--------------------------------------------

def convertPDF(pdf_file, allow_ocr=False):
    """Extract the text and basic metadata of a PDF file.

    Args:
        pdf_file: Path of the PDF to read.
        allow_ocr: If True and the PDF contains images but fewer than 1000
            characters of extractable text, run ocrmypdf on it (the result is
            written next to the input as "<name>_ocr<ext>") and read the text
            from the OCR result instead.

    Returns:
        Tuple (page_list, full_text, metadata). page_list holds the text of
        every page that yielded text, full_text is these texts joined by blank
        lines, and metadata is a dict with the document information fields,
        image_count, page_count (number of pages in the document) and
        char_count (length of full_text).
    """
    def extract_text_from_pdf(reader):
        # Keep only pages that actually yield text.
        page_list = [text for text in ((page.extract_text() or "") for page in reader.pages) if len(text) > 0]
        full_text = "\n\n".join(page_list).strip()
        return full_text, len(reader.pages), page_list

    reader = PdfReader(pdf_file)
    # Check if there are any images
    image_count = sum(len(page.images) for page in reader.pages)
    # Extract the text first, so the OCR decision below is based on real content.
    full_text, page_count, page_list = extract_text_from_pdf(reader)
    # If there are images and not much content, perform OCR on the document
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
            # splitext keeps this working for upper-case extensions such as ".PDF"
            root, ext = os.path.splitext(pdf_file)
            out_pdf_file = root + "_ocr" + (ext or ".pdf")
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            full_text, page_count, page_list = extract_text_from_pdf(reader)
    l = len(page_list)
    print(f"{l} Pages")
    # Extract metadata; the reader reports None when the document has no
    # information dictionary, in which case every field is None.
    info = reader.metadata
    metadata = {
        "author": getattr(info, "author", None),
        "creator": getattr(info, "creator", None),
        "producer": getattr(info, "producer", None),
        "subject": getattr(info, "subject", None),
        "title": getattr(info, "title", None),
        "image_count": image_count,
        "page_count": page_count,
        "char_count": len(full_text),
    }
    return page_list, full_text, metadata


#------------------------------------------
# Function for splitting text with overlap
#------------------------------------------

def split_with_overlap0(text,chunk_size=3500, overlap=700):
  """Cut `text` into character windows of at most `chunk_size` characters.

  Consecutive windows start `chunk_size - overlap` characters apart (at least
  one character), so neighbouring chunks share up to `overlap` characters.
  """
  stride = max(1, chunk_size - overlap)
  total = len(text)
  return [text[start:min(start + chunk_size, total)] for start in range(0, total, stride)]

import re
def split_with_overlap(text, chunk_size=3500, overlap=700, pattern=r'([.!;?][ \n\r]|[\n\r]{2,})', variant=1, verbose=False):
    """Split text into overlapping chunks whose borders follow regex matches.

    The text is scanned in windows of at most `chunk_size` characters. Inside a
    window the chunk end is moved to the last match of `pattern` that satisfies
    the rule of the selected variant, and for every window after the first the
    chunk start may be moved behind the first match of the window.

    Args:
        text: Text to split.
        chunk_size: Maximum number of characters per window.
        overlap: Size of the overlap region; capped at `chunk_size`.
        pattern: Regex marking possible split positions. The default matches
            sentence punctuation followed by whitespace, or two or more line breaks.
        variant: 1 and 2 move the end to matches at or behind the left overlap
            region; 3 additionally requires the match to lie in the right
            overlap region. Variant 1 always skips the text before the first
            match of a later window, variants 2 and 3 only when that match lies
            in the left overlap region.
        verbose: Print tracing information while splitting.

    Returns:
        List of non-empty chunks; an empty list for empty text.

    NOTE(review): skipping the text before the first match assumes that this
    text is already part of the previous chunk; with a small or zero overlap
    text can be left out of all chunks - verify before using such settings.
    """
    chunks = []
    overlap = min(overlap, chunk_size)  # the overlap cannot be larger than chunk_size
    # `end` starts at 0 so the remainder check after the loop is defined for empty text
    i, end, lastEnd = 0, 0, 0
    while i < len(text):
        if verbose: print("i="+str(i))
        end = min(i + chunk_size, len(text))
        matches = list(re.finditer(pattern, text[i:end]))  # all matches in the window
        pattern_match = matches[0] if matches else None    # first occurrence (if any)
        step = max(1, chunk_size - overlap)  # by default advance by chunk_size - overlap
        if pattern_match:  # at least one split position was found
            for s in (m.start() for m in matches):
                # Match is not inside the left overlap (variants 1 and 2) or additionally
                # inside the right overlap (variant 3, may yield incomplete sentences).
                if ((variant<=2 and s>=overlap) or (variant==3 and s>=overlap and s>(chunk_size-overlap))):
                    end = s+i+1  # move the end to the start of the match within the whole text
                    if verbose: print("***move end:"+str(end)+"; step="+str(step))
                    # Advance at most up to this match (only relevant when it is not in the right overlap).
                    if s<(chunk_size-overlap): step = min(step, max(1, s-overlap))
            if ((variant==1 and i>0) or (variant>=2 and pattern_match.start()<overlap and i>0)):
                i = i+pattern_match.start()+1  # drop the text in front of the first match
        if verbose: print("i="+str(i)+"; end="+str(end)+"; step="+str(step)+"; len="+str(len(text))+"; match="+str(pattern_match)+"; text="+text[i:end]+"; rest="+text[end:])
        # Add a chunk only if the end moved forward (not just a shortened start of a
        # known sentence) and the chunk is not empty.
        if end > lastEnd and end > i:
            chunks.append(text[i:end])
        lastEnd = end
        if verbose: print("Text at position "+str(i)+": "+text[i:end])
        i += step
    if len(text[end:]) > 0: chunks.append(text[end:])  # append a possible remainder
    return chunks

# Regular expressions for sentence and section borders. Raw strings are used so
# that regex escapes such as \( reach the regex engine without being treated as
# (invalid) string escapes.
# Negative look-behinds that prevent a split after common abbreviations; every
# alternative inside one look-behind has the same width, as look-behinds require.
fiveChars=  r"(?<![ \n\(]bspw|[ \n]inkl)"
fourChars=  r"(?<![ \n\(]sog|[ \n]Mio|[ \n]Mrd|[ \n]Tsd|[ \n]Tel)"
threeChars= r"(?<!www|bzw|etc|ggf|[ \n\(]al|[ \n\(]St|[ \n\(]dh|[ \n\(]va|[ \n\(]ca|[ \n\(]Dr|[ \n\(]Hr|[ \n\(]Fr|[0-9]ff)"
twoChars=   r"(?<![ \n\(][A-Za-zΆ-Ωά-ωäöüß])"
oneChars=   r"(?<![0-9.])"
# Sentence end: ., ? or ! preceded by four non-dot characters and none of the
# abbreviations above, and not followed by a letter, digit, quote or further punctuation.
sentenceRegex=r"(?<=[^.]{4})"+fiveChars+fourChars+threeChars+twoChars+oneChars+r"[.?!](?![A-Za-zΆ-Ωά-ωäöüß0-9.!?'\"])"
# Section border: an empty line (optionally containing blanks) and following blank space.
sectionRegex=r"\n[ ]*\n[\n ]*"
splitRegex="("+sentenceRegex+"|"+sectionRegex+")"


#---------------------------------------------------------------
# Function for adding docs to ChromaDB and/or return collection
#---------------------------------------------------------------

def add_doc(path, session):
  """Return the ChromaDB collection of a session, indexing a PDF into it first if one is given.

  Args:
    path: Path of a PDF file. Any value that does not end in ".pdf" or does not
      exist on disk means "no document": the collection is returned unchanged.
    session: Session identifier; the collection is named "DB_<session>".

  Returns:
    The ChromaDB collection of the session (created if it does not exist yet).

  On setups without CUDA only the first 5 pages of longer PDFs are indexed.
  Chunks are only added while the collection is still empty.
  """
  print("def add_doc!")
  print(path)
  doc=None # text of the attached PDF; stays None when no PDF is attached
  if(str.lower(path).endswith(".pdf") and os.path.exists(path)):
    pages=convertPDF(path)[0]
    if(len(pages)>5 and not "cuda" in device):
      pages=pages[0:5]
      gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing excerpt (demo-mode: first 5 pages on CPU setups)!")
    else:
      gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing!")
    doc="\n\n".join(pages)
  else:
    gr.Info("No PDF attached - answer based on DB_"+str(session)+".")
  client = chromadb.PersistentClient(path=dbPath)
  print(str(client.list_collections()))
  print(str(session))
  dbName="DB_"+str(session)
  # Let ChromaDB decide whether the collection exists. Searching the printed
  # collection list for the name would also match longer names that start with
  # dbName and depends on how the list is printed.
  collection = client.get_or_create_collection(
    name=dbName,
    embedding_function=embeddingModel,
    metadata={"hnsw:space": "cosine"})
  if(doc is not None):
    # A PDF without extractable text yields no chunks; blank chunks are dropped.
    corpus=split_with_overlap(doc,3500,700,pattern=splitRegex) if doc.strip() else []
    corpus=[c for c in corpus if c.strip()]
    print("Length of corpus: "+str(len(corpus)))
    print("Corpus:"+str(corpus))
    then = datetime.now()
    x=collection.get(include=[])["ids"]
    print(len(x))
    # NOTE(review): a PDF uploaded to a session that already holds documents is
    # not indexed - confirm whether this is intended.
    if(len(x)==0):
      chunkSize=40000
      nBatches=(len(corpus)+chunkSize-1)//chunkSize # ceiling division, no empty trailing batch
      for i,start in enumerate(range(0,len(corpus),chunkSize)):
        print("embed batch "+str(i)+" of "+str(nBatches))
        batch=corpus[start:start+chunkSize]
        ids=[str(start+j+len(x)+1) for j in range(len(batch))] # chromadb-unique IDs, counted from 1
        # NOTE(review): the stored date is a fixed value - confirm whether the current date is meant.
        collection.add(documents=batch, ids=ids,
          metadatas=[{"date": str("2024-10-10")} for b in batch])
        print("finished batch "+str(i)+" of "+str(nBatches))
    now = datetime.now()
    gr.Info(f"Indexing complete!")
    print(now-then) # duration of splitting/embedding
  return(collection)


#--------------------------------------------------------
# Function for response to user queries and pot. addenda
#--------------------------------------------------------

def multimodal_response(message, history, dropdown, hfToken, request: gr.Request):
  """Answer a chat message with retrieval-augmented generation and stream the answer.

  Args:
    message: Multimodal chat message, a dict with the keys "text" and "files".
    history: Previous (user, bot) turns of the chat.
    dropdown: Value of the "Variante" dropdown (currently not used).
    hfToken: Hugging Face token; used for the InferenceClient when it starts with "hf_".
    request: Gradio request; its session_hash selects the session database
      ("0" is used when no request is available).

  Yields:
    The answer generated so far and, as last value, the answer followed by an
    HTML list of the retrieved passages.
  """
  import html # local import: only needed to escape the passages for the HTML list
  print("def multimodal response!")
  if(hfToken.startswith("hf_")): # use HF-hub with custom token if token is provided
    inferenceClient = InferenceClient(model=myModel, token=hfToken)
  else:
    inferenceClient = InferenceClient(myModel)
  session=request.session_hash if request else "0"
  print(databases)
  if(not databases[-1][1]==session):
    databases.append((date.today(),session))
  query=message["text"]
  # Only an uploaded file is handed to add_doc as a path. Without attachment an
  # empty path is passed, which just returns the session collection; passing the
  # message text would let users index any PDF on the server by typing its path.
  attachment=message["files"][0] if len(message["files"])>0 else ""
  collection=add_doc(attachment, session)
  # Embed a simple string of the recent user queries when there is a history,
  # otherwise the query itself.
  if(len(history)>0):
    ragQuery=[format_prompt(query, history, historylimit=2,
      startOfString="", template1="{message}\n\n",template2="")]
  else:
    ragQuery=[query]
  result=collection.query(query_texts=ragQuery, n_results=3)
  context=["Kontext "+str(i+1)+": \""+str(c).replace("\"","'")+"\"" for i,c in enumerate(result["documents"][0])]
  gr.Info("Kontext:\n"+str(context))
  generate_kwargs = dict(
        temperature=float(0.9),
        max_new_tokens=5000,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
  )
  system="Mit Blick auf das folgende Gespräch und den relevanten Kontext, antworte auf die aktuelle Frage des Nutzers. "+\
  "Antworte ausschließlich auf Basis der Informationen im Kontext.\n\nKontext:\n\n"+\
  str("\n\n".join(context))
  print(system)
  formatted_prompt = format_prompt(query, history,system=system)
  print(formatted_prompt)
  output = ""
  if(not "GGUF" in myModel): # text generation via Hugging Face
    try:
      stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
      for response in stream:
        output += response.token.text
        yield output
    except Exception as e:
      output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
      if(len(context)>0):
        output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
      yield output
      print(str(e))
  else: # text generation via the local llama-cpp server
    try:
      payload={"prompt":formatted_prompt,"max_tokens":1000, "echo":False,"stream":True}
      url="http://0.0.0.0:2600/v1/completions"
      print("URL: "+url)
      print("User: "+str(message)+"\nAssistant: ")
      with requests.post(url, json=payload, stream=True) as response:
        response.raise_for_status()
        # The server answers with server-sent events: one "data: {json}" line per
        # generated piece of text. Reading complete lines avoids cutting an event
        # or a character at an arbitrary byte position.
        for line in response.iter_lines():
          line=line.decode("utf-8").strip()
          if(not line.startswith("data:")): continue # blank separators and ": ping" keep-alives
          data=line[len("data:"):].strip()
          if(data=="[DONE]"): break # end marker of OpenAI-style streams, if the server sends one
          try:
            part = str(json.loads(data)["choices"][0]["text"])
          except (ValueError, KeyError, IndexError, TypeError) as e:
            print("Exception:"+str(e)) # skip events that cannot be parsed
            continue
          print(part, end="", flush=True)
          output += part
          yield output
    except Exception as e:
      output = "Die KI antwortet gerade nicht."
      if(len(context)>0):
        output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
      yield output
      print(str(e))
  if(len(context)>0):
    # The passages come from uploaded documents, so they are escaped before
    # being placed into HTML.
    output=output+"\n\n<br><details open><summary><strong>Quellen</strong></summary><br><ul>"+ "".join(["<li>" + html.escape(c, quote=False) + "</li>" for c in context])+"</ul></details>"
  yield output

#------------------------------
# Launch Gradio-ChatInterface
#------------------------------


# Additional inputs of the chat; their values are passed to multimodal_response
# after message and history, in this order.
variantDropdown=gr.Dropdown(
  info="Wähle eine Variante",
  choices=["1","2","3"],
  value="1",
  label="Variante")
tokenBox=gr.Textbox(
  value="",
  label="HF_token")

# Multimodal chat interface, so that files can be attached to messages.
i=gr.ChatInterface(
  multimodal_response,
  title="Frag dein PDF",
  multimodal=True,
  additional_inputs=[variantDropdown, tokenBox])
i.launch() #allowed_paths=["."])