# install.packages("~/Documents/fastrerandomize-software/fastrerandomize",repos = NULL, type = "source",force = F)
# ============================================================
# app.R | Shiny App for Rerandomization with fastrerandomize
# ============================================================
# 1) The user can upload or simulate a covariate dataset (X).
# 2) They specify rerandomization parameters: n_treated, acceptance prob, etc.
# 3) The app generates a set of accepted randomizations under rerandomization.
# 4) The user can optionally upload or simulate outcomes (Y) and run a randomization test.
# 5) The app displays the distribution of the balance measure (e.g., Hotelling's
#    T^2), the final p-value and fiducial interval, and run-time comparisons
#    between the fastrerandomize and base R implementations.
#
# ----------------------------
# Load required packages
# ----------------------------
options(error=NULL)
library(shiny)
library(shinydashboard)
library(DT) # For data tables
library(ggplot2) # For basic plotting
library(fastrerandomize) # Our rerandomization package
library(parallel) # For detecting CPU cores
# For production apps, ensure fastrerandomize is installed:
# install.packages("devtools")
# devtools::install_github("cjerzak/fastrerandomize-software/fastrerandomize")
# ---------------------------------------------------------
# UI Section
# ---------------------------------------------------------
# UI definition: a three-tab dashboard (data -> randomizations -> test).
# Output ids used here must match the output$... renderers defined in `server`.
ui <- dashboardPage(
  # ========== Header =================
  dashboardHeader(
    title = span(
      style = "font-weight: 600; font-size: 14px;",
      a(
        href = "https://fastrerandomize.github.io/",
        "fastrerandomize.github.io",
        target = "_blank",
        # Prevent the opened tab from getting a handle on this window.
        rel = "noopener noreferrer",
        style = "color: white; text-decoration: underline; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;"
      )
    )
  ),
  # ========== Sidebar ================
  dashboardSidebar(
    sidebarMenu(
      menuItem("1. Data & Covariates", tabName = "datatab", icon = icon("database")),
      menuItem("2. Generate Randomizations", tabName = "gennet", icon = icon("random")),
      menuItem("3. Randomization Test", tabName = "randtest", icon = icon("flask")),
      # ---- Minimal "Share" button: markup placeholder + inline JS ----
      # The script looks up an element with id 'share-button' and wires it to
      # the Web Share API, falling back to copying the page URL.
      # NOTE(review): the markup placeholder below is currently empty, so no
      # 'share-button' element exists; the script exits early in that case
      # instead of throwing a TypeError on page load.
      tags$div(
        style = "text-align: left; margin: 1em 0 1em 1em;",
        HTML("\n"),
        tags$script(
          HTML("
(function() {
  const shareBtn = document.getElementById('share-button');
  // Nothing to wire up when the button markup is not on the page.
  if (!shareBtn) {
    return;
  }
  // Reusable helper function to show a small 'Copied!' message
  function showCopyNotification() {
    const notification = document.createElement('div');
    notification.innerText = 'Copied to clipboard';
    notification.style.position = 'fixed';
    notification.style.bottom = '20px';
    notification.style.right = '20px';
    notification.style.backgroundColor = 'rgba(0, 0, 0, 0.8)';
    notification.style.color = '#fff';
    notification.style.padding = '8px 12px';
    notification.style.borderRadius = '4px';
    notification.style.zIndex = '9999';
    document.body.appendChild(notification);
    setTimeout(() => { notification.remove(); }, 2000);
  }
  shareBtn.addEventListener('click', function() {
    const currentURL = window.location.href;
    const pageTitle = document.title || 'Check this out!';
    // If browser supports Web Share API
    if (navigator.share) {
      navigator.share({
        title: pageTitle,
        text: '',
        url: currentURL
      })
      .catch((error) => {
        console.log('Sharing failed', error);
      });
    } else {
      // Fallback: Copy URL
      if (navigator.clipboard && navigator.clipboard.writeText) {
        navigator.clipboard.writeText(currentURL).then(() => {
          showCopyNotification();
        }, (err) => {
          console.error('Could not copy text: ', err);
        });
      } else {
        // Double fallback for older browsers
        const textArea = document.createElement('textarea');
        textArea.value = currentURL;
        document.body.appendChild(textArea);
        textArea.select();
        try {
          document.execCommand('copy');
          showCopyNotification();
        } catch (err) {
          alert('Please copy this link:\\n' + currentURL);
        }
        document.body.removeChild(textArea);
      }
    }
  });
})();
")
        )
      ),
      # ---- End: Minimal Share button snippet ----
      tags$div(
        style = "text-align: left; margin: 4em 0 1em 1em;",
        HTML("\nCitation:\nfastrerandomize (2025).\nPDF |\nBibTeX\n")
      )
    )
  ),
  # ========== Body ===================
  dashboardBody(
    # A little CSS to keep the design timeless and clean
    tags$head(
      tags$style(HTML("
        .smalltext { font-size: 90%; color: #555; }
        .shiny-output-error { color: red; }
        .shiny-input-container { margin-bottom: 15px; }
      "))
    ),
    tabItems(
      # ------------------------------------------------
      # 1) Data & Covariates Tab
      # ------------------------------------------------
      tabItem(
        tabName = "datatab",
        fluidRow(
          box(
            width = 5, title = "Covariate Data: Upload or Simulate",
            status = "primary", solidHeader = TRUE,
            radioButtons("data_source", "Data Source:",
                         choices = c("Upload CSV" = "upload",
                                     "Simulate data" = "simulate"),
                         selected = "simulate"),
            conditionalPanel(
              condition = "input.data_source == 'upload'",
              fileInput("file_covariates", "Choose CSV File",
                        accept = c(".csv")),
              helpText("Columns = features/covariates, rows = units.")
            ),
            conditionalPanel(
              condition = "input.data_source == 'simulate'",
              numericInput("sim_n", "Number of units (rows)",
                           value = 64, min = 10),
              numericInput("sim_p", "Number of covariates (columns)",
                           value = 32, min = 2),
              actionButton("simulate_btn", "Simulate X")
            )
          ),
          box(
            width = 7, title = "Preview of Covariates (X)",
            status = "info", solidHeader = TRUE,
            DTOutput("covariates_table")
          )
        )
      ),
      # ------------------------------------------------
      # 2) Generate Randomizations Tab
      # ------------------------------------------------
      tabItem(
        tabName = "gennet",
        fluidRow(
          box(
            width = 4, title = "Rerandomization Parameters",
            status = "primary", solidHeader = TRUE,
            numericInput("n_treated", "Number Treated (n_treated)",
                         value = 10, min = 1),
            selectInput("random_type", "Randomization Type:",
                        choices = c("Monte Carlo" = "monte_carlo",
                                    "Exact" = "exact"),
                        selected = "monte_carlo"),
            numericInput("accept_prob", "Acceptance Probability (stringency)",
                         value = 0.01, min = 0.0001, max = 1),
            # Draw/batch controls only apply to Monte Carlo sampling.
            conditionalPanel(
              condition = "input.random_type == 'monte_carlo'",
              numericInput("max_draws", "Max Draws (MC)", value = 1e5, min = 1e3),
              numericInput("batch_size", "Batch Size (MC)", value = 1e3, min = 1e2)
            ),
            actionButton("generate_btn", "Generate")
          ),
          box(
            width = 8, title = "Summary of Accepted Randomizations",
            status = "info", solidHeader = TRUE,
            # First row of boxes: accepted randomizations and min balance measure
            fluidRow(
              column(width = 6, valueBoxOutput("n_accepted_box", width = 12)),
              column(width = 6, valueBoxOutput("balance_min_box", width = 12))
            ),
            # Second row of boxes: fastrerandomize time & base R time
            fluidRow(
              column(width = 6, valueBoxOutput("fastrerand_time_box", width = 12)),
              column(width = 6, valueBoxOutput("baseR_time_box", width = 12))
            ),
            br(),
            plotOutput("balance_hist", height = "250px"),
            # Hardware info note
            br(),
            uiOutput("hardware_info")
          )
        )
      ),
      # ------------------------------------------------
      # 3) Randomization Test Tab
      # ------------------------------------------------
      tabItem(
        tabName = "randtest",
        fluidRow(
          box(
            width = 4, title = "Randomization Test Setup",
            status = "primary", solidHeader = TRUE,
            # Outcome source: simulate from X, or upload a single-column CSV.
            radioButtons("outcome_source", "Outcome Data (Y):",
                         choices = c("Simulate Y" = "simulate",
                                     "Upload CSV" = "uploadY"),
                         selected = "simulate"),
            conditionalPanel(
              condition = "input.outcome_source == 'simulate'",
              numericInput("true_tau", "True Effect (simulate)", 1, step = 0.5),
              numericInput("noise_sd", "Noise SD for Y", 0.5, step = 0.1),
              actionButton("simulateY_btn", "Simulate Y")
            ),
            conditionalPanel(
              condition = "input.outcome_source == 'uploadY'",
              fileInput("file_outcomes", "Choose CSV File with outcome vector Y",
                        accept = c(".csv")),
              helpText("Single column with length = #units.")
            ),
            br(),
            actionButton("run_randtest_btn", "Run Test"),
            checkboxInput("findFI", "Compute Fiducial Interval?", value = TRUE)
          ),
          box(
            width = 6, title = "Preview of Outcomes (Y)",
            status = "info", solidHeader = TRUE,
            DTOutput("outcomes_table")
          )
        ),
        fluidRow(
          box(
            width = 4, title = NULL, status = NULL,
            background = NULL, solidHeader = FALSE, collapsible = FALSE,
            tags$p("Note: Relative speedups greatest for large number of accepted randomizations.",
                   style = "color:#555; font-size:90%; margin:0;")
          ),
          box(
            width = 8, title = "Test Results", status = "info", solidHeader = TRUE,
            # First row: p-value and observed effect (fastrerandomize)
            fluidRow(
              column(width = 6, valueBoxOutput("pvalue_box", width = 12)),
              column(width = 6, valueBoxOutput("tauobs_box", width = 12))
            ),
            # Second row: fastrerandomize test time & base R test time
            fluidRow(
              column(width = 6, valueBoxOutput("fastrerand_test_time_box", width = 12)),
              column(width = 6, valueBoxOutput("baseR_test_time_box", width = 12))
            ),
            # Show fastrerandomize FI
            uiOutput("fi_text"),
            # Base R results in a separate section
            tags$hr(),
            fluidRow(
              column(width = 6, valueBoxOutput("pvalue_box_baseR", width = 12)),
              column(width = 6, valueBoxOutput("tauobs_box_baseR", width = 12))
            ),
            fluidRow(
              column(width = 12, uiOutput("fi_text_baseR"))
            ),
            br(),
            plotOutput("test_plot", height = "280px")
          )
        )
      )
    ) # end tabItems
  ) # end dashboardBody
) # end dashboardPage
# ---------------------------------------------------------
# SERVER
# ---------------------------------------------------------
# Server logic: holds per-session state (X, accepted randomizations, Y, test
# results and timings) in reactiveVals and renders every output the UI declares.
server <- function(input, output, session) {
  # -------------------------------------------------------
  # 0. Session-local helpers
  # -------------------------------------------------------

  # TRUE when `x` is a single finite number. numericInput() yields NA when the
  # field is cleared, so raw numeric inputs are checked before use.
  is_finite_scalar <- function(x) {
    is.numeric(x) && length(x) == 1L && is.finite(x)
  }

  # Evaluate `expr` lazily inside tryCatch and time it (wall clock).
  # Returns list(result = value or error condition, time = difftime in secs);
  # errors are returned rather than raised so callers can report them.
  run_timed <- function(expr) {
    t0 <- Sys.time()
    result <- tryCatch(expr, error = function(e) e)
    list(result = result, time = difftime(Sys.time(), t0, units = "secs"))
  }

  # Store `result` in reactiveVal `rv`; if it is an error, notify the user
  # (message prefixed by `context`) and store NULL instead.
  store_or_notify <- function(rv, result, context) {
    if (inherits(result, "error")) {
      showNotification(paste(context, conditionMessage(result)), type = "error")
      rv(NULL)
    } else {
      rv(result)
    }
  }

  # The first accepted randomization is used as the "observed" assignment.
  # Returns NULL when no randomizations are available.
  first_randomization <- function(rr) {
    if (is.null(rr) || NROW(rr$randomizations) < 1) {
      return(NULL)
    }
    rr$randomizations[1, ]
  }

  # Value box for an elapsed time (difftime), or "---" before the first run.
  time_value_box <- function(tm, label, color) {
    value <- if (is.null(tm)) "---" else round(as.numeric(tm), 3)
    valueBox(value, label, icon = icon("clock"), color = color)
  }

  # Value box for a test result's p-value, or "---" when unavailable.
  pvalue_value_box <- function(rt, label) {
    if (is.null(rt) || is.null(rt$p_value)) {
      valueBox("---", label, icon = icon("question"), color = "blue")
    } else {
      valueBox(round(rt$p_value, 4), label, icon = icon("list-check"), color = "purple")
    }
  }

  # Value box for a test result's observed effect, or "---" when unavailable.
  tauobs_value_box <- function(rt, label) {
    if (is.null(rt) || is.null(rt$tau_obs)) {
      valueBox("---", label, icon = icon("question"), color = "maroon")
    } else {
      valueBox(round(rt$tau_obs, 4), label, icon = icon("bullseye"), color = "maroon")
    }
  }

  # Fiducial-interval UI from a test result's `FI` (lower, upper); NULL if absent.
  fi_ui <- function(rt, label) {
    if (is.null(rt) || is.null(rt$FI)) {
      return(NULL)
    }
    tagList(
      strong(label),
      p(sprintf("[%.4f, %.4f]", round(rt$FI[1], 4), round(rt$FI[2], 4)))
    )
  }

  # -------------------------------------------------------
  # Session state
  # -------------------------------------------------------
  X_data <- reactiveVal(NULL)             # covariate matrix (units x covariates)
  RerandResult <- reactiveVal(NULL)       # fastrerandomize accepted randomizations
  RerandResult_base <- reactiveVal(NULL)  # base R accepted randomizations
  fastrand_time <- reactiveVal(NULL)
  baseR_time <- reactiveVal(NULL)
  Y_data <- reactiveVal(NULL)             # outcome vector
  RandTestResult <- reactiveVal(NULL)
  RandTestResult_base <- reactiveVal(NULL)
  fastrand_test_time <- reactiveVal(NULL)
  baseR_test_time <- reactiveVal(NULL)

  # -------------------------------------------------------
  # 1. Covariate Data Handling
  # -------------------------------------------------------
  # Upload: read the CSV (header row expected) and keep numeric columns only,
  # since a character column would turn the whole matrix into character data.
  observeEvent(input$file_covariates, {
    req(input$file_covariates)
    df <- tryCatch(
      read.csv(input$file_covariates$datapath, header = TRUE),
      error = function(e) {
        showNotification(paste("Could not read covariate CSV:", conditionMessage(e)),
                         type = "error")
        NULL
      }
    )
    if (is.null(df)) {
      return(NULL)
    }
    numeric_cols <- vapply(df, is.numeric, logical(1))
    if (!any(numeric_cols)) {
      showNotification("The uploaded CSV has no numeric columns.", type = "error")
      return(NULL)
    }
    if (!all(numeric_cols)) {
      showNotification(
        paste("Dropping non-numeric column(s):",
              paste(names(df)[!numeric_cols], collapse = ", ")),
        type = "warning"
      )
    }
    X_data(as.matrix(df[, numeric_cols, drop = FALSE]))
  })

  # Simulate: i.i.d. N(0, 1) covariates of the requested size.
  observeEvent(input$simulate_btn, {
    n <- input$sim_n
    p <- input$sim_p
    if (!is_finite_scalar(n) || !is_finite_scalar(p) || n < 1 || p < 1) {
      showNotification("Number of units and covariates must be positive numbers.",
                       type = "error")
      return(NULL)
    }
    n <- floor(n)
    p <- floor(p)
    X_data(matrix(rnorm(n * p), nrow = n, ncol = p))
  })

  # Show X in table, rounded to 3 significant digits.
  output$covariates_table <- renderDT({
    req(X_data())
    df <- as.data.frame(X_data())
    numeric_cols <- vapply(df, is.numeric, logical(1))
    df[numeric_cols] <- lapply(df[numeric_cols], signif, digits = 3)
    datatable(df, options = list(scrollX = TRUE, pageLength = 10))
  })

  # -------------------------------------------------------
  # 2. Generate Rerandomizations (fastrerandomize vs base R, both timed)
  # -------------------------------------------------------
  observeEvent(input$generate_btn, {
    req(X_data())
    X <- X_data()
    n_treated <- input$n_treated
    # validate()/need() would raise a *silent* error inside an observer, so the
    # message would never be shown; notify explicitly instead.
    if (!is_finite_scalar(n_treated) || n_treated < 1) {
      showNotification("Number treated must be a positive number.", type = "error")
      return(NULL)
    }
    if (n_treated > nrow(X)) {
      showNotification("Number treated cannot exceed total units.", type = "error")
      return(NULL)
    }

    is_mc <- identical(input$random_type, "monte_carlo")
    max_draws <- if (is_mc) input$max_draws else NULL
    batch_size <- if (is_mc) input$batch_size else NULL

    withProgress(message = "Computing results...", value = 0, {
      # =========== 1) fastrerandomize generation timing ===========
      fast <- run_timed(generate_randomizations(
        n_units = nrow(X),
        n_treated = n_treated,
        X = X,
        randomization_accept_prob = input$accept_prob,
        randomization_type = input$random_type,
        max_draws = max_draws,
        batch_size = batch_size,
        verbose = FALSE
      ))
      store_or_notify(RerandResult, fast$result,
                      "Error generating randomizations (fastrerandomize):")
      fastrand_time(fast$time)
      incProgress(0.5)

      # =========== 2) base R generation timing ===========
      base <- run_timed(generate_randomizations_R(
        n_units = nrow(X),
        n_treated = n_treated,
        X = X,
        accept_prob = input$accept_prob,
        random_type = input$random_type,
        max_draws = max_draws,
        batch_size = batch_size
      ))
      store_or_notify(RerandResult_base, base$result,
                      "Error generating randomizations (base R):")
      baseR_time(base$time)
      incProgress(0.5)
    })
  })

  # Summaries of accepted randomizations
  output$n_accepted_box <- renderValueBox({
    rr <- RerandResult()
    if (is.null(rr) || is.null(rr$randomizations)) {
      valueBox("0", "Accepted Randomizations", icon = icon("ban"), color = "red")
    } else {
      valueBox(nrow(rr$randomizations), "Accepted Randomizations",
               icon = icon("check"), color = "green")
    }
  })

  output$balance_min_box <- renderValueBox({
    rr <- RerandResult()
    if (is.null(rr) || is.null(rr$balance)) {
      valueBox("---", "Min Balance Measure", icon = icon("question"), color = "orange")
    } else {
      valueBox(round(min(rr$balance), 3), "Min Balance Measure",
               icon = icon("thumbs-up"), color = "blue")
    }
  })

  # Timings for generation: fastrerandomize and base R
  output$fastrerand_time_box <- renderValueBox({
    time_value_box(fastrand_time(), "fastrerandomize generation time (secs)", "teal")
  })
  output$baseR_time_box <- renderValueBox({
    time_value_box(baseR_time(), "base R generation time (secs)", "lime")
  })

  # Histogram of the balance measure (fastrerandomize result)
  output$balance_hist <- renderPlot({
    rr <- RerandResult()
    req(rr, rr$balance)
    df <- data.frame(balance = rr$balance)
    bin_width <- diff(range(df$balance)) / 30
    # All balances equal (e.g. a single accepted draw) gives a zero width,
    # which geom_histogram rejects; fall back to a unit-width bin.
    if (!is.finite(bin_width) || bin_width <= 0) {
      bin_width <- 1
    }
    ggplot(df, aes(x = balance)) +
      geom_histogram(binwidth = bin_width, fill = "darkblue", alpha = 0.7) +
      labs(title = "Distribution of Balance Statistic",
           subtitle = "Among Accepted Randomizations",
           x = "Balance (i.e., T^2)",
           y = "Frequency") +
      theme_minimal(base_size = 14)
  })

  # Hardware info (CPU cores, GPU note). Explicit <br> tags are used because
  # raw newlines inside HTML collapse into spaces.
  output$hardware_info <- renderUI({
    num_cores <- detectCores(logical = TRUE)
    tags$div(
      class = "smalltext",
      tags$strong("System Hardware Info:"), tags$br(),
      paste("Number of CPU cores detected:", num_cores), tags$br(),
      "With additional CPU or GPU, greater speedups can be expected.", tags$br(),
      "Note: Speedups greatest in high-dimensional or large-N settings."
    )
  })

  # -------------------------------------------------------
  # 3. Randomization Test
  # -------------------------------------------------------
  # (A) Simulate Y = X %*% beta + tau * W + noise, beta ~ N(0, 1), using the
  #     first accepted randomization as the observed assignment W.
  observeEvent(input$simulateY_btn, {
    req(RerandResult())
    obsW <- first_randomization(RerandResult())
    if (is.null(obsW)) {
      showNotification("No accepted randomizations found. Cannot simulate Y for the 'observed' assignment.",
                       type = "error")
      return(NULL)
    }
    if (!is_finite_scalar(input$true_tau) || !is_finite_scalar(input$noise_sd) ||
        input$noise_sd < 0) {
      showNotification("True effect and a non-negative noise SD are required to simulate Y.",
                       type = "error")
      return(NULL)
    }
    nunits <- length(obsW)
    Xval <- X_data()
    if (is.null(Xval)) {
      showNotification("No covariate data found to help simulate outcomes. Using intercept-only model.",
                       type = "warning")
      Xval <- matrix(0, nrow = nunits, ncol = 1)
    }
    # X may have been replaced since the randomizations were generated.
    if (nrow(Xval) != nunits) {
      showNotification("Covariate data no longer matches the generated randomizations. Regenerate randomizations first.",
                       type = "error")
      return(NULL)
    }
    beta <- rnorm(ncol(Xval), 0, 1)
    Ysim <- as.numeric(Xval %*% beta +
                         obsW * input$true_tau +
                         rnorm(nunits, 0, input$noise_sd))
    Y_data(Ysim)
  })

  # (B) Upload Y: a single-column CSV without a header row.
  observeEvent(input$file_outcomes, {
    req(input$file_outcomes)
    dfy <- tryCatch(
      read.csv(input$file_outcomes$datapath, header = FALSE),
      error = function(e) {
        showNotification(paste("Could not read outcome CSV:", conditionMessage(e)),
                         type = "error")
        NULL
      }
    )
    if (is.null(dfy)) {
      return(NULL)
    }
    if (ncol(dfy) > 1) {
      showNotification("Please provide a single-column CSV for Y.", type = "error")
      return(NULL)
    }
    # Coercion warnings are reported to the user below instead.
    y <- suppressWarnings(as.numeric(dfy[[1]]))
    if (anyNA(y)) {
      showNotification("Y contains missing or non-numeric values (e.g. a header row); check the uploaded file.",
                       type = "warning")
    }
    Y_data(y)
  })

  # Preview of Y, alongside the observed assignment when one of matching
  # length is available (Y may be uploaded before randomizations exist).
  output$outcomes_table <- renderDT({
    req(Y_data())
    y <- Y_data()
    obsW <- first_randomization(RerandResult())
    dfy <- if (!is.null(obsW) && length(obsW) == length(y)) {
      data.frame(obsW = obsW, Y = y)
    } else {
      data.frame(Y = y)
    }
    dfy[] <- lapply(dfy, function(col) {
      if (is.numeric(col)) signif(col, 3) else col
    })
    datatable(dfy, options = list(scrollX = TRUE, pageLength = 5))
  })

  observeEvent(input$run_randtest_btn, {
    rr <- RerandResult()
    obsW <- first_randomization(rr)
    if (is.null(obsW)) {
      showNotification("No accepted randomizations found. Generate randomizations first.",
                       type = "error")
      return(NULL)
    }
    obsY <- Y_data()
    if (is.null(obsY)) {
      showNotification("No outcome data Y found. Upload or simulate first.", type = "error")
      return(NULL)
    }
    if (length(obsY) != length(obsW)) {
      showNotification(sprintf("Outcome length (%d) does not match the number of units (%d).",
                               length(obsY), length(obsW)),
                       type = "error")
      return(NULL)
    }

    withProgress(message = "Computing results...", value = 0, {
      # =========== 1) fastrerandomize randomization_test timing ===========
      fast <- run_timed(randomization_test(
        obsW = obsW,
        obsY = obsY,
        candidate_randomizations = rr$randomizations,
        findFI = input$findFI
      ))
      store_or_notify(RandTestResult, fast$result,
                      "Error in randomization_test (fastrerandomize):")
      fastrand_test_time(fast$time)
      incProgress(0.5)

      # =========== 2) base R randomization test timing ===========
      # Reset (rather than silently keep) stale base R results when no base R
      # randomizations are available.
      # NOTE(review): obsW comes from the fastrerandomize draw and may not be
      # among the base R candidates; confirm randomization_test_R allows that.
      rr_base <- RerandResult_base()
      if (is.null(rr_base) || NROW(rr_base$randomizations) < 1) {
        showNotification("No base R randomizations found. Cannot run base R test.",
                         type = "error")
        RandTestResult_base(NULL)
        baseR_test_time(NULL)
      } else {
        base <- run_timed(randomization_test_R(
          obsW = obsW,
          obsY = obsY,
          allW = rr_base$randomizations,
          findFI = input$findFI
        ))
        store_or_notify(RandTestResult_base, base$result,
                        "Error in randomization_test (base R):")
        baseR_test_time(base$time)
      }
      incProgress(0.5)
    })
  })

  # p-value and observed effect (fastrerandomize)
  output$pvalue_box <- renderValueBox({
    pvalue_value_box(RandTestResult(), "p-value (fastrerandomize)")
  })
  output$tauobs_box <- renderValueBox({
    tauobs_value_box(RandTestResult(), "Observed Effect")
  })

  # p-value and observed effect (base R).
  # NOTE(review): assumes randomization_test_R returns the same `p_value` /
  # `tau_obs` fields as randomization_test; boxes show "---" otherwise.
  output$pvalue_box_baseR <- renderValueBox({
    pvalue_value_box(RandTestResult_base(), "p-value (base R)")
  })
  output$tauobs_box_baseR <- renderValueBox({
    tauobs_value_box(RandTestResult_base(), "Observed Effect (base R)")
  })

  # Times for randomization test
  output$fastrerand_test_time_box <- renderValueBox({
    time_value_box(fastrand_test_time(), "fastrerandomize test time (secs)", "teal")
  })
  output$baseR_test_time_box <- renderValueBox({
    time_value_box(baseR_test_time(), "base R test time (secs)", "lime")
  })

  # Fiducial intervals, when requested and returned
  output$fi_text <- renderUI({
    fi_ui(RandTestResult(), "Fiducial Interval (95%, fastrerandomize):")
  })
  output$fi_text_baseR <- renderUI({
    fi_ui(RandTestResult_base(), "Fiducial Interval (95%, base R):")
  })

  # The full randomization distribution is not stored, so only the observed
  # effect (fastrerandomize) is plotted, against a reference line at zero.
  output$test_plot <- renderPlot({
    rt <- RandTestResult()
    if (is.null(rt)) {
      plot.new()
      title("No test results yet.")
      return(NULL)
    }
    obs_val <- rt$tau_obs
    ggplot(data.frame(x = obs_val, y = 0), aes(x, y)) +
      geom_point(size = 4, color = "red") +
      xlim(c(obs_val - abs(obs_val) * 2 - 1, obs_val + abs(obs_val) * 2 + 1)) +
      labs(title = "Observed Treatment Effect (fastrerandomize)",
           x = "Effect Size", y = "") +
      theme_minimal(base_size = 14) +
      geom_vline(xintercept = 0, linetype = "dashed", color = "gray40")
  })
}
# ---------------------------------------------------------
# Run the Application
# ---------------------------------------------------------
# Assemble the app object from the UI definition and the server closure;
# returning it at top level launches the app when this file is run.
shinyApp(server = server, ui = ui)