# Spaces: Runtime error
import json | |
import os | |
from typing import TypedDict, List | |
import pandas as pd | |
import streamlit as st | |
from huggingface_hub import HfFileSystem | |
class InputFieldDict(TypedDict):
    """Declarative description of one element rendered on the annotation page.

    The keys are declared in the order ``name``, ``type``, ``title``. The app
    unpacks ``field.values()`` positionally, so instances must be created
    with their keys in this same order.
    """

    # Identifier of the element: the CSV column to display for "input_col"
    # elements, and the key under which the answer is stored for widgets.
    name: str
    # Kind of element; the app distinguishes "input_col", "markdown",
    # "slider" and "text".
    type: str
    # Label (or markdown text) shown to the annotator.
    title: str
# Function to get user ID from URL
def get_user_id_from_url():
    """Return the ``user_id`` query parameter of the page URL.

    ``st.query_params`` maps each key to a plain string, so the value is
    returned as-is. Indexing it with ``[0]`` would keep only the first
    character of the ID and raise ``IndexError`` for an empty ``?user_id=``.

    Returns:
        The user ID taken from the URL, or an empty string when the
        parameter is absent.
    """
    return st.query_params.get("user_id", "")
# Hugging Face token taken from the HF_TOKEN_WRITE environment variable
# (None when the variable is not set).
HF_TOKEN = os.getenv("HF_TOKEN_WRITE")
# Report only whether the token is missing, never the token itself.
print("is none?", HF_TOKEN is None)
# File-system client used for all reads and writes against the repos below.
hf_fs = HfFileSystem(token=HF_TOKEN)
##### Change these variables to your own paths #####
input_repo_path = "datasets/ijundi/annotate-test"  # repo the CSV is read from
output_repo_path = "datasets/ijundi/annotate-test"  # repo the annotations are written to
to_annotate_file_name = "to_annotate.csv"  # CSV file to annotate
# Columns of the current CSV row that are copied into every saved record.
COLS_TO_SAVE = ["id"]

# Page layout, in display order. Each triple is (name, type, title).
fields: List[InputFieldDict] = [
    InputFieldDict(name=field_name, type=field_type, title=field_title)
    for field_name, field_type, field_title in (
        ##### Change these fields to your own fields (name is column name from to_annotate_file_name) #####
        ("text", "input_col", "**Text:**"),
        ("score", "input_col", "**Score:**"),
        ##### Change these fields to your own fields #####
        ("name 3", "slider", "Question 3"),
        ("name 4", "markdown", "---\n Section"),
        ("name 5", "slider", "Question 5"),
        ("name 6", "slider", "Question 6"),
        ("sep", "markdown", "---"),
        ("name 8", "slider", "Question 8"),
        ("name 9", "slider", "Question 9"),
        ("name 10", "slider", "Question 10"),
        ("name 11", "slider", "Question 11"),
        ("name 12", "text", "Question 12"),
    )
]
def read_data(_path):
    """Load a CSV file from the input repo into a DataFrame.

    Args:
        _path: Path of the CSV file, relative to ``input_repo_path``.

    Returns:
        The ``pandas.DataFrame`` parsed from the file.
    """
    remote_path = "/".join((input_repo_path, _path))
    with hf_fs.open(remote_path) as csv_file:
        frame = pd.read_csv(csv_file)
    return frame
def read_saved_data():
    """Load the annotation already saved for the current user and row.

    The file location is ``output_repo_path`` joined with ``get_path()``.

    Returns:
        The decoded JSON content of the saved file, or ``None`` when no
        such file exists.
    """
    remote_path = "/".join((output_repo_path, get_path()))
    # Guard clause: nothing has been saved yet for this user/row.
    if not hf_fs.exists(remote_path):
        return None
    with hf_fs.open(remote_path) as saved_file:
        return json.load(saved_file)
# Write a remote file
def save_data(data):
    """Serialize ``data`` as JSON and write it to the output repo.

    The target file is ``get_path()`` under ``output_repo_path``; the
    directory named after ``data['user_id']`` is created first.

    Args:
        data: Mapping to store; must contain a ``'user_id'`` key. Values
            that are not JSON-serializable are written as ``str(value)``.
    """
    user_dir = f"{output_repo_path}/{data['user_id']}"
    hf_fs.mkdir(user_dir)
    target = f"{output_repo_path}/{get_path()}"
    with hf_fs.open(target, "w") as out_file:
        payload = json.dumps(data, default=str)
        out_file.write(payload)
def get_path():
    """Return the relative path of the saved annotation for the current page.

    Returns:
        ``"<user_id>/<current_index>.json"``, built from the values held in
        ``st.session_state``.
    """
    state = st.session_state
    user_id, row_index = state.user_id, state.current_index
    return f"{user_id}/{row_index}.json"
#################################### Streamlit App ####################################
# Function to navigate rows
def navigate(index_change):
    """Shift the current row index and rerun the script.

    Args:
        index_change: Signed offset added to
            ``st.session_state.current_index``.
    """
    state = st.session_state
    new_index = state.current_index + index_change
    state.current_index = new_index
    print(new_index)
    # Rerun right away; see the thread below about buttons otherwise
    # needing a second click before the state change shows up.
    # https://discuss.streamlit.io/t/click-twice-on-button-for-changing-state/45633/2
    st.rerun()
# st.set_page_config(layout='wide')
# Title of the app
st.title("Simple Annotation App")
# Load the data to annotate (only when it is not already in session state)
if "data" not in st.session_state:
    st.session_state["data"] = read_data(to_annotate_file_name)
# Initialize the current index; -1 selects the user-ID entry page
if "current_index" not in st.session_state:
    st.session_state["current_index"] = -1
if st.session_state.current_index == -1: | |
st.session_state.user_id = st.text_input('Please enter your user ID to proceed', value=get_user_id_from_url()) | |
if st.button("Next"): | |
navigate(1) | |
elif st.session_state.current_index < len(st.session_state.data): | |
st.write(f"username is {st.session_state.user_id}") | |
# Creating the form | |
with st.form("feedback_form"): | |
index = st.session_state.current_index | |
data_collected = read_saved_data() | |
st.session_state.default_values = {} | |
st.session_state.data_inputs = {} | |
for field in fields: | |
_name, _type, _title = field.values() | |
key = _name + str(index) | |
match _type: | |
case 'input_col': | |
st.write(_title) | |
st.write(st.session_state.data.iloc[index][_name]) | |
case 'markdown': | |
st.markdown(_title) | |
case _: | |
value = st.session_state.default_values[_name] = data_collected[_name] if data_collected else \ | |
{'slider': 0, | |
'text': None}[_type] | |
if _type == 'slider': | |
st.session_state.data_inputs[_name] = st.slider(_title, min_value=0, max_value=100, step=25, | |
key=key, | |
value=value) | |
else: | |
st.session_state.data_inputs[_name] = st.text_area(_title, key=key, value=value) | |
submitted = st.form_submit_button("Submit") | |
if submitted: | |
with st.spinner(text="saving"): | |
save_data({ | |
'user_id': st.session_state.user_id, | |
'index': st.session_state.current_index, | |
**{k: st.session_state.data.iloc[index][k] for k in COLS_TO_SAVE}, | |
**st.session_state.data_inputs | |
}) | |
st.success("Feedback submitted successfully!") | |
navigate(1) | |
else: | |
st.write("Finished all data points!") | |
# Navigation buttons
page_index = st.session_state.current_index
# The "Previous" button is only rendered once past the first row; the
# short-circuit keeps st.button from being called otherwise.
if page_index > 0 and st.button("Previous"):
    with st.spinner(text="in progress"):
        navigate(-1)
# Page indicator, shown only while a data row is being displayed.
if 0 <= page_index < len(st.session_state.data):
    row_count = len(st.session_state.data)
    st.write(f"Page {page_index + 1} out of {row_count}")