import json import os from typing import TypedDict, List import pandas as pd import streamlit as st from huggingface_hub import HfFileSystem class InputFieldDict(TypedDict): name: str type: str title: str # Function to get user ID from URL def get_user_id_from_url(): user_id = st.query_params.get("user_id", [""])[0] # Default to empty string if not found return user_id HF_TOKEN = os.environ.get("HF_TOKEN_WRITE") print("is none?", HF_TOKEN is None) hf_fs = HfFileSystem(token=HF_TOKEN) ##### Change these variables to your own paths ##### input_repo_path = 'datasets/ijundi/annotate-test' output_repo_path = 'datasets/ijundi/annotate-test' to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate COLS_TO_SAVE = ['id'] fields: List[InputFieldDict] = [ ##### Change these fields to your own fields (name is column name from to_annotate_file_name) ##### InputFieldDict(name="text", type="input_col", title="**Text:**"), InputFieldDict(name="score", type="input_col", title="**Score:**"), ##### Change these fields to your own fields ##### InputFieldDict(name="name 3", type="slider", title="Question 3"), InputFieldDict(name="name 4", type="markdown", title="---\n Section"), InputFieldDict(name="name 5", type="slider", title="Question 5"), InputFieldDict(name="name 6", type="slider", title="Question 6"), InputFieldDict(name="sep", type="markdown", title="---"), InputFieldDict(name="name 8", type="slider", title="Question 8"), InputFieldDict(name="name 9", type="slider", title="Question 9"), InputFieldDict(name="name 10", type="slider", title="Question 10"), InputFieldDict(name="name 11", type="slider", title="Question 11"), InputFieldDict(name="name 12", type="text", title="Question 12"), ] def read_data(_path): with hf_fs.open(input_repo_path + '/' + _path) as f: return pd.read_csv(f) def read_saved_data(): _path = get_path() if hf_fs.exists(output_repo_path + '/' + _path): with hf_fs.open(output_repo_path + '/' + _path) as f: return json.load(f) return None # Write a remote file def save_data(data): hf_fs.mkdir(f"{output_repo_path}/{data['user_id']}") with hf_fs.open(f"{output_repo_path}/{get_path()}", "w") as f: f.write(json.dumps(data, default=str)) def get_path(): return f"{st.session_state.user_id}/{st.session_state.current_index}.json" #################################### Streamlit App #################################### # Function to navigate rows def navigate(index_change): st.session_state.current_index += index_change print(st.session_state.current_index) # https://discuss.streamlit.io/t/click-twice-on-button-for-changing-state/45633/2 st.rerun() # st.set_page_config(layout='wide') # Title of the app st.title("Simple Annotation App") # Load the data to annotate if 'data' not in st.session_state: st.session_state.data = read_data(to_annotate_file_name) # Initialize the current index if 'current_index' not in st.session_state: st.session_state.current_index = -1 if st.session_state.current_index == -1: st.session_state.user_id = st.text_input('Please enter your user ID to proceed', value=get_user_id_from_url()) if st.button("Next"): navigate(1) elif st.session_state.current_index < len(st.session_state.data): st.write(f"username is {st.session_state.user_id}") # Creating the form with st.form("feedback_form"): index = st.session_state.current_index data_collected = read_saved_data() st.session_state.default_values = {} st.session_state.data_inputs = {} for field in fields: _name, _type, _title = field.values() key = _name + str(index) match _type: case 'input_col': st.write(_title) st.write(st.session_state.data.iloc[index][_name]) case 'markdown': st.markdown(_title) case _: value = st.session_state.default_values[_name] = data_collected[_name] if data_collected else \ {'slider': 0, 'text': None}[_type] if _type == 'slider': st.session_state.data_inputs[_name] = st.slider(_title, min_value=0, max_value=100, step=25, key=key, value=value) else: st.session_state.data_inputs[_name] = st.text_area(_title, key=key, value=value) submitted = st.form_submit_button("Submit") if submitted: with st.spinner(text="saving"): save_data({ 'user_id': st.session_state.user_id, 'index': st.session_state.current_index, **{k: st.session_state.data.iloc[index][k] for k in COLS_TO_SAVE}, **st.session_state.data_inputs }) st.success("Feedback submitted successfully!") navigate(1) else: st.write("Finished all data points!") # Navigation buttons if st.session_state.current_index > 0: if st.button("Previous"): with st.spinner(text="in progress"): navigate(-1) if 0 <= st.session_state.current_index < len(st.session_state.data): st.write(f"Page {st.session_state.current_index + 1} out of {len(st.session_state.data)}")