|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from transformers import GPT2Tokenizer |
|
from pathlib import Path |
|
import streamlit as st |
|
from typing import List, Dict, Any, Callable |
|
from pred import * |
|
from load_data import * |
|
|
|
def main():
    """Streamlit app: generate football ("footy") commentary with a Seq2Seq model.

    Loads a GPT-2 tokenizer and a trained encoder/decoder checkpoint, renders
    a sidebar for choosing the decoding strategy and its hyper-parameters, and
    a main panel that lets the user draw a random test sample and generate
    commentary for it.
    """
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_bos_token=True)
    # GPT-2 ships without a pad token; reuse EOS so padding is well-defined.
    tokenizer.pad_token = tokenizer.eos_token

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Encoder/Decoder/Seq2Seq come from pred (star import). o=50257 matches
    # the GPT-2 vocabulary size so the decoder predicts over the full vocab.
    encoder = Encoder(h=64, n=2, e=64, a=4, o=64).to(device)
    decoder = Decoder(h=64, n=2, e=64, a=4, o=50257).to(device)
    model = Seq2Seq(encoder, decoder).to(device)

    checkpoint = torch.load('./seq2seq_checkpoint.pt', weights_only=True, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference only: put dropout/normalization layers into eval behavior.
    model.eval()

    st.title("Footy Commentary Generator")

    st.sidebar.header("Configuration")
    tab_selection = st.sidebar.radio(
        "Select Input Method:",
        ["Random Sample from Test Set", "Custom Input"]
    )

    _decoding_sidebar()

    if tab_selection == "Random Sample from Test Set":
        _random_sample_panel(tokenizer, encoder, decoder, device)


def _decoding_sidebar():
    """Render the decoding-configuration sidebar; persist choices in session_state."""
    st.sidebar.header("Decoding Configuration")
    st.session_state.decoding_mode = st.sidebar.selectbox(
        "Decoding Mode",
        ["greedy", "sample", "top-k", "diverse-beam-search", "min-bayes-risk"]
    )

    params = {
        'max_len': st.sidebar.slider('Max length', 1, 500, 50),
        'temperature': st.sidebar.slider('Temperature', 0.0, 1.0, 0.1),
    }
    # Mode-specific knobs; keys must match genOp's keyword arguments.
    if st.session_state.decoding_mode == "top-k":
        params["k"] = st.sidebar.slider("k value", 1, 100, 5)
    elif st.session_state.decoding_mode == "diverse-beam-search":
        params["beam_width"] = st.sidebar.slider("beam width", 1, 10, 1)
        params["diversity_penalty"] = st.sidebar.slider("diversity penalty", 0.0, 1.0, 0.1)
    elif st.session_state.decoding_mode == "min-bayes-risk":
        params["num_candidates"] = st.sidebar.slider("Number of candidates", 1, 30, 4)
    st.session_state.decoding_params = params


def _random_sample_panel(tokenizer, encoder, decoder, device):
    """Main panel: draw a random test sample, display it, and run generation."""
    st.header("Generate from Test Dataset")

    col1, col2 = st.columns([3, 1])
    with col1:
        st.write("Test dataset contains 5000 samples")
    with col2:
        if st.button("Generate Random Sample"):
            # Cover the full range [0, 5000). The original randint(1, 5000)
            # could never select sample 0 — an off-by-one.
            # NOTE(review): assumes get_sample is 0-indexed — TODO confirm.
            random_idx = np.random.randint(0, 5000)
            st.session_state.random_idx = random_idx
            (st.session_state.ip,
             st.session_state.ip_mask,
             st.session_state.tg,
             st.session_state.tg_mask) = get_sample(random_idx)

    if 'random_idx' in st.session_state:
        st.subheader(f"Sample #{st.session_state.random_idx}")
        # ip is batched (take row 0); tg is a flat token sequence.
        st.session_state.x = tokenizer.decode(st.session_state.ip.tolist()[0])
        st.session_state.y = tokenizer.decode(st.session_state.tg.tolist())

        df = pd.DataFrame.from_dict({'X': [st.session_state.x], 'y': [st.session_state.y]})
        st.dataframe(df.T.reset_index(), width=800)

    if st.button("Generate Sequence"):
        # Guard: without a drawn sample there is no input to decode from;
        # the original crashed with AttributeError here.
        if 'ip' not in st.session_state:
            st.warning("Please generate a random sample first.")
        else:
            with st.spinner("Generating sequence..."):
                print(f'Ip: {st.session_state.ip} | Mask: {st.session_state.ip_mask} \n Mode: {st.session_state.decoding_mode} | Params: {st.session_state.decoding_params}')
                st.session_state.tok_output = genOp(
                    encoder, decoder, device,
                    st.session_state.ip,
                    st.session_state.ip_mask,
                    mode=st.session_state.decoding_mode,
                    **st.session_state.decoding_params
                )
                print(f'\n\n\nOutput: {st.session_state.tok_output} \n')
                st.session_state.output = tokenizer.decode(st.session_state.tok_output)

    if 'output' in st.session_state:
        st.subheader("Generated Sequence")
        st.write(st.session_state.output)
|
# Script entry point: run the Streamlit app when executed directly.
# (A stray indented "1" literal that followed this guard was removed —
# at module top level it raised IndentationError and prevented the
# script from running at all.)
if __name__ == "__main__":
    main()