Spaces:
Sleeping
Sleeping
import os | |
import wandb | |
import streamlit as st | |
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification | |
from pdf2image import convert_from_bytes | |
from PIL import Image | |
# W&B credentials are read from the environment (e.g. a Space secret).
# The feedback workflow below logs to W&B via wandb.init() / run.log(), so the
# app cannot do its job without them.
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    st.error(
        "Couldn't find the W&B API key. Please set WANDB_API_KEY as an environment variable.",
        icon="🚨",
    )
    # Halt this script run here instead of continuing into wandb.init()
    # further down, which needs an authenticated session.
    st.stop()
wandb.login(key=wandb_api_key)
# Class names for the document classifier. The script maps the model's argmax
# index to a name by list position (via id2label), so this order is assumed to
# match the label order used when the model was fine-tuned — TODO confirm.
labels = [
    "budget",
    "email",
    "form",
    "handwritten",
    "invoice",
    "language",
    "letter",
    "memo",
    "news article",
    "questionnaire",
    "resume",
    "scientific publication",
    "specification",
]

# Class id -> class name, and the inverse name -> class id.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in enumerate(labels)}
MODEL_DIR = "model/layoutlmv3/"


@st.cache_resource
def _load_model_and_processor(model_dir=MODEL_DIR):
    """Load the LayoutLMv3 classifier and its processor from ``model_dir``.

    Decorated with ``st.cache_resource`` so the objects are created once per
    server process and shared by every rerun and every user session. Caching
    them in ``st.session_state`` instead would load a separate copy of the
    weights for each browser session.

    Returns:
        tuple: ``(model, processor)``.
    """
    loaded_model = LayoutLMv3ForSequenceClassification.from_pretrained(model_dir)
    loaded_processor = LayoutLMv3Processor.from_pretrained(model_dir)
    return loaded_model, loaded_processor


_cached_model, _cached_processor = _load_model_and_processor()
# Still populate the session_state keys so any code reading
# st.session_state.model / st.session_state.processor keeps working.
if "model" not in st.session_state:
    st.session_state.model = _cached_model
if "processor" not in st.session_state:
    st.session_state.processor = _cached_processor
model = st.session_state.model
processor = st.session_state.processor
st.title("Document Classification with LayoutLMv3")
uploaded_file = st.file_uploader(
    "Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
)

# Streamlit re-executes this whole script on every widget interaction. A table
# created unconditionally here would be replaced by an empty one on each rerun,
# so rows added via "Add feedback" would be gone by the time "Submit all
# feedback" is clicked. Keep a single table per session in st.session_state.
if "feedback_table" not in st.session_state:
    st.session_state.feedback_table = wandb.Table(columns=[
        "image", "filetype", "predicted_label", "predicted_label_id",
        "correct_label", "correct_label_id",
    ])
feedback_table = st.session_state.feedback_table

# One W&B run per session, created on the first script run of that session.
# NOTE(review): wandb manages a single active run per process by default;
# confirm how a second concurrent session's wandb.init() interacts with this.
if "wandb_run" not in st.session_state:
    st.session_state.wandb_run = wandb.init(project="hydra-classifier", name="feedback-loop")
def classify_image(_image, index=None):
    """Predict the class id of a single page image.

    Args:
        _image: PIL image of one document page. It is converted to RGB first
            so palette/RGBA/greyscale PNG uploads reach the processor as
            3-channel images (assumed to be what the processor expects —
            confirm against the processor config).
        index: Optional page index, used only in the debug log lines.

    Returns:
        int: argmax over the model's logits for the first (only) batch item;
        look it up in ``id2label`` for the class name.
    """
    import torch  # transformers already needs torch for return_tensors="pt"

    tag = "" if index is None else f" with index {index}"
    print(f"Encoding image{tag}")
    encoding = processor(
        _image.convert("RGB"),
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    print(f"Predicting image{tag}")
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encoding)
    return outputs.logits.argmax(-1)[0].item()
if uploaded_file:
    # A PDF yields one PIL image per page; an image upload is a single page.
    if uploaded_file.type == "application/pdf":
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        images = [Image.open(uploaded_file)]

    for i, image in enumerate(images):
        st.image(image, caption=f"Uploaded Image {i}", use_container_width=True)
        prediction = classify_image(image)
        st.write(f"Prediction: {id2label[prediction]}")

        # Widget keys include the page index so every page gets its own state.
        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f"prediction-{i}",
        )
        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f"selectbox-{i}",
            )
            print(f"Correct label for image {i}: {correct_label}")
            # Only a confirmed correction is recorded in the feedback table.
            if st.button(f"Add feedback for Image {i}", key=f"add-{i}"):
                feedback_table.add_data(
                    wandb.Image(image),
                    uploaded_file.type,
                    id2label[prediction],
                    prediction,
                    correct_label,
                    label2id[correct_label],
                )

    if st.button("Submit all feedback", key="submit"):
        if not feedback_table.data:
            st.warning("No feedback has been added yet.")
        else:
            run = st.session_state.wandb_run
            run.log({"feedback_table": feedback_table})
            run.finish()
            # The run is finished now: drop it (and the logged table, if it is
            # kept in session_state) so the next rerun starts a fresh run and
            # table instead of logging to a finished run.
            st.session_state.pop("wandb_run", None)
            st.session_state.pop("feedback_table", None)
            st.success("Feedback submitted!")