Spaces:
Runtime error
Runtime error
File size: 4,656 Bytes
fe1089d 2492536 fe1089d 5d99c07 dacf466 226ad46 dacf466 fe1089d f301e04 a597c76 f301e04 fe1089d 5d99c07 fe1089d 2492536 5d99c07 fe1089d f301e04 21aad16 5d99c07 a597c76 5d99c07 a597c76 5d99c07 2492536 f301e04 ba1dc89 2492536 fe1089d 21aad16 dacf466 ba1dc89 226ad46 fe1089d 2492536 fe1089d 2492536 b324c38 fe1089d 2492536 fe1089d f301e04 fe1089d 5d99c07 fe1089d 30049a9 fe1089d d2116db b324c38 fe1089d 30049a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# controller for the application that calls the model and explanation functions
# returns the updated conversation history and extra elements
# external imports
import gradio as gr
# internal imports
from model import godel
from model import mistral
from explanation import (
attention as attention_viz,
interpret_shap as shap_int,
interpret_captum as cpt_int,
)
# simple chat function that calls the model
# formats prompts, calls for an answer and returns updated conversation history
def vanilla_chat(
    model, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Answer one chat message with the given model, without any explanation.

    The model object must provide ``format_prompt(message, history,
    system_prompt, knowledge)`` and ``respond(prompt)``.

    Args:
        model: Model wrapper used to build the prompt and generate the reply.
        message: The user's new message.
        history: Conversation so far; the new (message, reply) pair is
            appended to this same list object.
        system_prompt: System prompt forwarded to ``format_prompt``.
        knowledge: Optional extra knowledge text forwarded to ``format_prompt``.

    Returns:
        A tuple ``("", history)``. The empty string is the new value for the
        prompt input box, and ``history`` is the updated list.
    """
    print(f"Running normal chat with {model}.")
    reply = model.respond(
        model.format_prompt(message, history, system_prompt, knowledge)
    )
    history.append((message, reply))
    return "", history
def explained_chat(
    model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Answer one chat message and collect the explanation produced by an XAI module.

    The model object must provide ``format_prompt(message, history,
    system_prompt, knowledge)``. The XAI module must provide
    ``chat_explained(model, prompt)``, which returns exactly four values:
    the answer, a graphic, a markup value and a plot.

    Args:
        model: Model wrapper used to build the prompt.
        xai: XAI module that generates the answer together with its explanation.
        message: The user's new message.
        history: Conversation so far; the new (message, answer) pair is
            appended to this same list object.
        system_prompt: System prompt forwarded to ``format_prompt``.
        knowledge: Optional extra knowledge text forwarded to ``format_prompt``.

    Returns:
        A 5-tuple ``("", history, graphic, markup, plot)``. The empty string
        resets the prompt input box; the last three items come straight from
        ``xai.chat_explained``.
    """
    print(f"Running explained chat with {xai} with {model}.")
    # NOTE(review): a `prompt_limiter` step was once sketched before prompt
    # formatting but is disabled; confirm whether prompt length needs limiting.
    full_prompt = model.format_prompt(message, history, system_prompt, knowledge)
    print(f"Formatted prompt: {full_prompt}")
    reply, graphic, markup, plot = xai.chat_explained(model, full_prompt)
    history.append((message, reply))
    return "", history, graphic, markup, plot
# main inference entry point: calls the matching chat function depending on the model and XAI selections
def interference(
prompt: str,
history: list,
knowledge: str,
system_prompt: str,
xai_selection: str,
model_selection: str,
):
# if no proper system prompt is given, use a default one
if system_prompt in ("", " "):
system_prompt = """
You are a helpful, respectful and honest assistant.
Always answer as helpfully as possible, while being safe.
"""
# if a model is selected, grab the model instance
if model_selection.lower() == "mistral":
model = mistral
print("Identified model as Mistral")
else:
model = godel
print("Identified model as GODEL")
# if a XAI approach is selected, grab the XAI module instance
# and call the explained chat function
if xai_selection in ("SHAP", "Attention"):
# matching selection
match xai_selection.lower():
case "shap":
if model_selection.lower() == "mistral":
xai = cpt_int
else:
xai = shap_int
case "attention":
xai = attention_viz
case _:
# use Gradio warning to display error message
gr.Warning(f"""
There was an error in the selected XAI Approach.
It is "{xai_selection}"
""")
# raise runtime exception
raise RuntimeError("There was an error in the selected XAI approach.")
# call the explained chat function with the model instance
prompt_output, history_output, xai_interactive, xai_markup, xai_plot = (
explained_chat(
model=model,
xai=xai,
message=prompt,
history=history,
system_prompt=system_prompt,
knowledge=knowledge,
)
)
# if no XAI approach is selected call the vanilla chat function
else:
# calling the vanilla chat function
prompt_output, history_output = vanilla_chat(
model=model,
message=prompt,
history=history,
system_prompt=system_prompt,
knowledge=knowledge,
)
# set XAI outputs to disclaimer html/none
xai_interactive, xai_markup, xai_plot = (
"""
<div style="text-align: center"><h4>Without Selected XAI Approach,
no graphic will be displayed</h4></div>
""",
[("", "")],
None,
)
# return the outputs
return prompt_output, history_output, xai_interactive, xai_markup, xai_plot
|