Spaces:
Runtime error
Runtime error
File size: 3,237 Bytes
2492536 d2116db 2492536 d2116db 517fd4c 1f063be 2492536 58a02af d2116db 2492536 f5ebee7 d2116db f5ebee7 5d99c07 f5ebee7 d2116db 1f063be a597c76 d2116db 2492536 67a34bd d2116db 58a02af 67a34bd d2116db 58a02af 2492536 58a02af 2492536 58a02af d2116db 2492536 d2116db 2492536 d2116db 2492536 67a34bd 1f063be d2116db 2492536 d2116db 2492536 d2116db 2492536 a597c76 229e14c 226ad46 229e14c 226ad46 d2116db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
# markup module that provides marked-up text as a list of (text, bucket) tuples
# external imports
import numpy as np
from numpy import ndarray
# internal imports
from utils import formatting as fmt
# main function that assigns each text snippet a marked bucket
def markup_text(input_text: list, text_values: ndarray, variant: str) -> list:
    """Assign each text snippet one of 11 buckets ("-5" to "+5") by its value.

    Args:
        input_text: text snippets, one per value.
        text_values: values belonging to the snippets. For variant "shap"
            they are transposed and passed through fmt.flatten_attribution,
            for "visualizer" through fmt.flatten_attention; any other
            variant uses the array as given.
        variant: source of the values ("shap" or "visualizer").

    Returns:
        A list of (text, bucket) tuples, bucket being one of the tags
        "-5" ... "+5". An empty input yields an empty list.

    Raises:
        ValueError: if the number of values differs from the number of
            text snippets.
    """
    print(f"Marking up text {input_text} for {variant}.")

    # naming of the 11 buckets
    bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
    # number of linspace points per side: the buckets of one side plus zero
    points_per_side = (len(bucket_tags) - 1) // 2 + 1

    # flatten the values depending on the source
    # attention is averaged, SHAP summed up
    if variant == "shap":
        text_values = np.transpose(text_values)
        text_values = fmt.flatten_attribution(text_values)
    elif variant == "visualizer":
        text_values = fmt.flatten_attention(text_values)

    if text_values.size != len(input_text):
        # report .size (not len) so the message matches the check above
        raise ValueError(
            "Length of input text and attribution values do not match. "
            f"Text: {len(input_text)}, Attributions: {text_values.size}"
        )

    # nothing to mark up; np.min/np.max would raise on an empty array
    if text_values.size == 0:
        return []

    # determine the minimum and maximum values
    min_val, max_val = np.min(text_values), np.max(text_values)

    # keep the threshold ranges on their own side of zero: without clamping,
    # all-negative values would produce negative "positive" thresholds and
    # the largest negative value would end up in bucket "+5"
    # visualization negative thresholds are all 0 since attention always positive
    lower = 0.0 if variant == "visualizer" else min(min_val, 0.0)
    upper = max(max_val, 0.0)

    # negative thresholds between the lower bound and 0 (both excluded)
    neg_thresholds = np.linspace(lower, 0, num=points_per_side, endpoint=False)[1:]
    # positive thresholds between 0 (excluded) and the upper bound (included)
    pos_thresholds = np.linspace(0, upper, num=points_per_side)[1:]
    # combining thresholds, one per bucket tag
    thresholds = np.concatenate([neg_thresholds, [0], pos_thresholds])

    # init empty marked text list
    marked_text = []

    # looping over each text snippet and attribution value
    for text, value in zip(input_text, text_values):
        # setting initial bucket at lowest
        bucket = bucket_tags[0]
        # the last bucket whose threshold the value reaches wins
        for tag, threshold in zip(bucket_tags, thresholds):
            if value >= threshold:
                bucket = tag
        # finally adding text and bucket assignment to list of tuples
        marked_text.append((text, bucket))

    # returning list of marked text snippets as list of tuples
    return marked_text
# function that defines the color code of each bucket
# colors follow the SHAP color scheme for consistency
def color_codes():
    """Return a new dict mapping each bucket tag to its hex color.

    Negative buckets run from strong sky blue ("-5") to a very light blue
    ("-1"), "0" is white (assuming default light mode) and positive buckets
    run from light pink ("+1") to strong magenta ("+5").
    """
    # bucket tags in display order, lowest to highest
    tags = ("-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5")
    # one hex color per tag, in the same order
    shades = (
        "#008bfb",
        "#68a1fd",
        "#96b7fe",
        "#bcceff",
        "#dee6ff",
        "#ffffff",
        "#ffd9d9",
        "#ffb3b5",
        "#ff8b92",
        "#ff5c71",
        "#ff0051",
    )
    return dict(zip(tags, shades))
|