João Pedro
fix some wording in some strings
c9ae4da
import os
import wandb
import streamlit as st
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from pdf2image import convert_from_bytes
from PIL import Image
# Authenticate with Weights & Biases using the key from the WANDB_API_KEY
# environment variable.
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    # Missing or empty key: report it in the UI. The script is not halted
    # here, so the statements that follow this block still execute.
    st.error(
        "Couldn't find WanDB API key. Please set it up as an environment variable",
        icon="🚨",
    )
else:
    wandb.login(key=wandb_api_key)
# Class names used by the app. A label's position in this list is its
# integer class id, so the order must not change.
labels = [
    "budget",
    "email",
    "form",
    "handwritten",
    "invoice",
    "language",
    "letter",
    "memo",
    "news article",
    "questionnaire",
    "resume",
    "scientific publication",
    "specification",
]

# Forward (id -> name) and reverse (name -> id) lookup tables.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
# Load the classifier and its processor from the local checkpoint directory
# the first time they are needed in a session, then reuse the objects kept in
# st.session_state.
_session = st.session_state
for _state_key, _loader in (
    ('model', LayoutLMv3ForSequenceClassification),
    ('processor', LayoutLMv3Processor),
):
    if _state_key not in _session:
        _session[_state_key] = _loader.from_pretrained("model/layoutlmv3/")

# Module-level aliases used by the rest of the script.
model = _session['model']
processor = _session['processor']
# Page header and the upload widget (one PDF/JPG/PNG file at a time).
st.title("Document Classification with LayoutLMv3")
uploaded_file = st.file_uploader(
    "Upload Document",
    type=["pdf", "jpg", "png"],
    accept_multiple_files=False,
)

# Table that receives user feedback rows. It is created unconditionally,
# i.e. a new, empty table every time this script executes.
_feedback_columns = [
    'image',
    'filetype',
    'predicted_label',
    'predicted_label_id',
    'correct_label',
    'correct_label_id',
]
feedback_table = wandb.Table(columns=_feedback_columns)

# Start the W&B run only if the session does not already hold one.
if 'wandb_run' not in st.session_state:
    st.session_state['wandb_run'] = wandb.init(
        name='feedback-loop',
        project='hydra-classifier',
    )
@st.cache_data
def _classify_cached(_image, image_digest):
    """Run the model on ``_image`` and return the predicted class id.

    ``_image`` carries a leading underscore so Streamlit does not try to hash
    it; the cache entry is keyed on ``image_digest`` instead, which must
    uniquely identify the image content.
    """
    print(f'Encoding image {image_digest[:12]}')
    encoding = processor(
        _image,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    print(f'Predicting image {image_digest[:12]}')
    outputs = model(**encoding)
    # Highest-scoring class of the first (only) item in the batch, as an int.
    return outputs.logits.argmax(-1)[0].item()


def classify_image(_image):
    """Return the predicted class id (an ``int``) for a PIL image.

    Previously this function was cached with ``_image`` as its only
    parameter. Streamlit leaves underscore-prefixed parameters out of the
    cache key, so every call returned the prediction computed for the first
    image ever classified. The body also read the module globals ``image``
    and ``i`` instead of its own argument.

    The image is now identified by a SHA-256 digest of its mode, size and
    pixel data, and the cached helper is keyed on that digest.
    """
    import hashlib

    hasher = hashlib.sha256()
    # Mode and size are mixed in so that images with identical raw bytes but
    # a different shape or pixel format do not share a cache entry.
    hasher.update(f'{_image.mode}:{_image.size}'.encode())
    hasher.update(_image.tobytes())
    return _classify_cached(_image, hasher.hexdigest())
# Confirmed corrections for this session, keyed by (file name, page index).
# Streamlit re-executes the script on every widget interaction and
# `feedback_table` is created anew on each execution, so rows added to the
# table when "Add feedback" is clicked would be gone by the time
# "Submit all feedback" is clicked on a later run. The rows are therefore
# kept in session state and copied into the table only at submit time.
if 'feedback_rows' not in st.session_state:
    st.session_state['feedback_rows'] = {}

if uploaded_file:
    if uploaded_file.type == "application/pdf":
        # One image per PDF page.
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        # PNG uploads may be RGBA or palette images; convert so the processor
        # always receives a 3-channel RGB image.
        images = [Image.open(uploaded_file).convert("RGB")]

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
        prediction = classify_image(image)
        st.write(f"Prediction: {id2label[prediction]}")

        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )

        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f'selectbox-{i}'
            )
            print(f'Correct label for image {i}: {correct_label}')

            # Confirm the correction. Re-confirming the same page overwrites
            # its earlier entry instead of adding a duplicate row.
            if st.button(f"Add feedback for Image {i}", key=f'add-{i}'):
                st.session_state['feedback_rows'][(uploaded_file.name, i)] = (
                    image,
                    uploaded_file.type,
                    id2label[prediction],
                    prediction,
                    correct_label,
                    label2id[correct_label],
                )

    # Outside the page loop: the widget key is fixed, so it may only be
    # created once per script run.
    if st.button("Submit all feedback", key='submit'):
        pending_rows = st.session_state['feedback_rows']
        if not pending_rows:
            # Nothing confirmed yet; do not log an empty table or end the run.
            st.warning("No feedback has been added yet.")
        else:
            for row_image, *row_fields in pending_rows.values():
                feedback_table.add_data(wandb.Image(row_image), *row_fields)

            run = st.session_state.wandb_run
            run.log({'feedback_table': feedback_table})
            run.finish()

            # The run is finished and cannot take further logs. Dropping it
            # from session state lets the `'wandb_run' not in
            # st.session_state` guard start a new run on the next execution.
            del st.session_state['wandb_run']
            st.session_state['feedback_rows'] = {}
            st.success("Feedback submitted!")