# Non-code header text (kept as a comment so the file parses):
# ijundi's picture | Update app.py | 046993f verified
import json
import os
from typing import TypedDict, List
import pandas as pd
import streamlit as st
from huggingface_hub import HfFileSystem
# Schema of one entry in `fields`: what it refers to, how the form renders it,
# and the label shown to the annotator.
InputFieldDict = TypedDict(
    "InputFieldDict",
    {
        # For "input_col": the CSV column to display; for answer fields: the
        # key the answer is stored under in the saved record.
        "name": str,
        # One of "input_col", "markdown", "slider", "text".
        "type": str,
        # Label (or markdown text) shown in the form.
        "title": str,
    },
)
# Function to get user ID from URL
def get_user_id_from_url():
user_id = st.query_params.get("user_id", [""])[0] # Default to empty string if not found
return user_id
# Hugging Face token taken from the HF_TOKEN_WRITE environment variable
# (None when unset). All repo reads/writes in this app go through hf_fs.
HF_TOKEN = os.getenv("HF_TOKEN_WRITE")
# Startup diagnostic: prints only whether the token is missing, never its value.
print("is none?", HF_TOKEN is None)
hf_fs = HfFileSystem(token=HF_TOKEN)
##### Change these variables to your own paths #####
# Dataset repo that to_annotate_file_name is read from (see read_data).
input_repo_path = 'datasets/ijundi/annotate-test'
# Dataset repo annotations are saved to, one JSON file per user and row
# (see save_data / read_saved_data). Here it is the same repo as the input.
output_repo_path = 'datasets/ijundi/annotate-test'
# CSV file to annotate
to_annotate_file_name = 'to_annotate.csv'
# CSV columns copied from the current row into every saved record.
COLS_TO_SAVE = ['id']

# Form layout, rendered top to bottom for every row:
#   "input_col" -> shows `title`, then the row's value in CSV column `name`
#   "markdown"  -> renders `title` with st.markdown
#   "slider"    -> 0-100 slider in steps of 25, answer saved under `name`
#   "text"      -> free-text area, answer saved under `name`
fields: List[InputFieldDict] = [
    ##### Change these fields to your own fields (name is column name from to_annotate_file_name) #####
    {"name": "text", "type": "input_col", "title": "**Text:**"},
    {"name": "score", "type": "input_col", "title": "**Score:**"},
    ##### Change these fields to your own fields #####
    {"name": "name 3", "type": "slider", "title": "Question 3"},
    {"name": "name 4", "type": "markdown", "title": "---\n Section"},
    {"name": "name 5", "type": "slider", "title": "Question 5"},
    {"name": "name 6", "type": "slider", "title": "Question 6"},
    {"name": "sep", "type": "markdown", "title": "---"},
    {"name": "name 8", "type": "slider", "title": "Question 8"},
    {"name": "name 9", "type": "slider", "title": "Question 9"},
    {"name": "name 10", "type": "slider", "title": "Question 10"},
    {"name": "name 11", "type": "slider", "title": "Question 11"},
    {"name": "name 12", "type": "text", "title": "Question 12"},
]
def read_data(_path):
    """Load a CSV file stored under ``input_repo_path`` into a pandas DataFrame.

    ``_path`` is the file name/path relative to ``input_repo_path``.
    """
    remote_csv = "/".join((input_repo_path, _path))
    with hf_fs.open(remote_csv) as handle:
        return pd.read_csv(handle)
def read_saved_data():
    """Return the annotation saved for the current user and row, if any.

    The record is looked up at ``get_path()`` under ``output_repo_path``.
    Returns the parsed JSON content (the dict written by ``save_data``), or
    ``None`` when no file exists there yet.
    """
    remote_json = "/".join((output_repo_path, get_path()))
    if not hf_fs.exists(remote_json):
        return None
    with hf_fs.open(remote_json) as handle:
        return json.load(handle)
def save_data(data):
    """Write one annotation record to the output dataset repo as JSON.

    ``data`` must contain a ``'user_id'`` key. ``hf_fs.mkdir`` is called for
    ``<output_repo_path>/<user_id>``, then the record is written to
    ``get_path()`` (``<user_id>/<current_index>.json`` from session state).
    Values JSON cannot encode natively are stored as their ``str()``.
    """
    user_folder = f"{output_repo_path}/{data['user_id']}"
    hf_fs.mkdir(user_folder)
    target = f"{output_repo_path}/{get_path()}"
    with hf_fs.open(target, "w") as out:
        out.write(json.dumps(data, default=str))
def get_path():
    """Return the repo-relative JSON path ``<user_id>/<current_index>.json`` for the current session."""
    state = st.session_state
    return f"{state.user_id}/{state.current_index}.json"
#################################### Streamlit App ####################################
def navigate(index_change):
    """Shift the current row index by ``index_change`` and rerun the script.

    Calling ``st.rerun()`` right away makes the new index take effect on this
    click rather than a second one, per
    https://discuss.streamlit.io/t/click-twice-on-button-for-changing-state/45633/2
    """
    new_index = st.session_state.current_index + index_change
    st.session_state.current_index = new_index
    print(new_index)  # debug trace of the new position
    st.rerun()
# Optional wide layout (disabled): st.set_page_config(layout='wide')
st.title("Simple Annotation App")

# Per-session state. The table to annotate is loaded once per session, and
# current_index == -1 marks the initial "enter your user ID" page.
if "data" not in st.session_state:
    st.session_state.data = read_data(to_annotate_file_name)
if "current_index" not in st.session_state:
    st.session_state.current_index = -1
if st.session_state.current_index == -1:
st.session_state.user_id = st.text_input('Please enter your user ID to proceed', value=get_user_id_from_url())
if st.button("Next"):
navigate(1)
elif st.session_state.current_index < len(st.session_state.data):
st.write(f"username is {st.session_state.user_id}")
# Creating the form
with st.form("feedback_form"):
index = st.session_state.current_index
data_collected = read_saved_data()
st.session_state.default_values = {}
st.session_state.data_inputs = {}
for field in fields:
_name, _type, _title = field.values()
key = _name + str(index)
match _type:
case 'input_col':
st.write(_title)
st.write(st.session_state.data.iloc[index][_name])
case 'markdown':
st.markdown(_title)
case _:
value = st.session_state.default_values[_name] = data_collected[_name] if data_collected else \
{'slider': 0,
'text': None}[_type]
if _type == 'slider':
st.session_state.data_inputs[_name] = st.slider(_title, min_value=0, max_value=100, step=25,
key=key,
value=value)
else:
st.session_state.data_inputs[_name] = st.text_area(_title, key=key, value=value)
submitted = st.form_submit_button("Submit")
if submitted:
with st.spinner(text="saving"):
save_data({
'user_id': st.session_state.user_id,
'index': st.session_state.current_index,
**{k: st.session_state.data.iloc[index][k] for k in COLS_TO_SAVE},
**st.session_state.data_inputs
})
st.success("Feedback submitted successfully!")
navigate(1)
else:
st.write("Finished all data points!")
# Navigation footer: a "Previous" button whenever current_index > 0, and a
# page counter while a data row is being shown.
_position = st.session_state.current_index
if _position > 0 and st.button("Previous"):
    with st.spinner(text="in progress"):
        navigate(-1)  # reruns the script, so nothing below runs after this
_row_count = len(st.session_state.data)
if 0 <= _position < _row_count:
    st.write(f"Page {_position + 1} out of {_row_count}")