cjerzak commited on
Commit
9443444
·
verified ·
1 Parent(s): f2cc599

Update app.R

Browse files
Files changed (1) hide show
  1. app.R +165 -31
app.R CHANGED
@@ -27,6 +27,7 @@ library(parallel) # For detecting CPU cores
27
  # ---------------------------------------------------------
28
  # HELPER FUNCTIONS (BASE R)
29
  # ---------------------------------------------------------
 
30
  # 1) Compute Hotelling's T^2 in base R
31
  baseR_hotellingT2 <- function(X, W) {
32
  # For a single assignment W:
@@ -52,12 +53,11 @@ baseR_hotellingT2 <- function(X, W) {
52
  }
53
 
54
  # 2) Generate randomizations in base R, filtering by acceptance probability
55
- # using T^2 and keep the best (lowest) fraction.
56
  baseR_generate_randomizations <- function(n_units, n_treated, X, accept_prob, random_type,
57
  max_draws, batch_size) {
58
 
59
  # For safety, check if exact enumerations will explode:
60
- # If random_type == "exact", we do combn(n_units, n_treated), which might be huge
61
  if (random_type == "exact") {
62
  n_comb_total <- choose(n_units, n_treated)
63
  if (n_comb_total > 1e6) {
@@ -136,21 +136,134 @@ baseR_generate_randomizations <- function(n_units, n_treated, X, accept_prob, ra
136
  list(randomizations = assignment_mat_accepted, balance = T2vals_accepted)
137
  }
138
 
139
- # 3) Base R randomization test: difference in means
140
- baseR_randomization_test <- function(obsW, obsY, allW) {
141
- # obs diff in means
142
- n1 <- sum(obsW)
143
- n0 <- length(obsW) - n1
144
- obs_diff <- mean(obsY[obsW == 1]) - mean(obsY[obsW == 0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  # for each candidate assignment, compute diff in means on obsY
147
- diffs <- apply(allW, 1, function(w) {
148
- mean(obsY[w == 1]) - mean(obsY[w == 0])
149
- })
150
 
151
  # p-value = fraction whose absolute diff >= observed
152
- pval <- mean(abs(diffs) >= abs(obs_diff))
153
- list(p_value = pval, tau_obs = obs_diff)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  }
155
 
156
  # ---------------------------------------------------------
@@ -310,7 +423,7 @@ ui <- dashboardPage(
310
 
311
  box(width = 8, title = "Test Results", status = "info", solidHeader = TRUE,
312
 
313
- # First row: p-value and observed effect
314
  fluidRow(
315
  column(width = 6, valueBoxOutput("pvalue_box", width = 12)),
316
  column(width = 6, valueBoxOutput("tauobs_box", width = 12))
@@ -322,7 +435,19 @@ ui <- dashboardPage(
322
  column(width = 6, valueBoxOutput("baseR_test_time_box", width = 12))
323
  ),
324
 
 
325
  uiOutput("fi_text"),
 
 
 
 
 
 
 
 
 
 
 
326
  br(),
327
  plotOutput("test_plot", height = "280px")
328
  )
@@ -390,7 +515,6 @@ server <- function(input, output, session) {
390
  "Number treated cannot exceed total units.")
391
  )
392
 
393
- # ------------------ COMPUTING RESULTS TOGGLE ------------------
394
  withProgress(message = "Computing results...", value = 0, {
395
 
396
  # =========== 1) fastrerandomize generation timing ===========
@@ -500,7 +624,6 @@ server <- function(input, output, session) {
500
  # Hardware info (CPU cores, GPU note)
501
  output$hardware_info <- renderUI({
502
  num_cores <- detectCores(logical = TRUE)
503
- # Basic note about GPU (this can be expanded if you have specialized checks)
504
  HTML(paste(
505
  "<strong>System Hardware Info:</strong><br/>",
506
  "Number of CPU cores detected:", num_cores, "<br/>",
@@ -517,8 +640,6 @@ server <- function(input, output, session) {
517
  observeEvent(input$simulateY_btn, {
518
  req(RerandResult())
519
  rr <- RerandResult()
520
-
521
- # We'll just use the first accepted randomization as the "observed" assignment
522
  if (is.null(rr$randomizations) || nrow(rr$randomizations) < 1) {
523
  showNotification("No accepted randomizations found. Cannot simulate Y for the 'observed' assignment.", type = "error")
524
  return(NULL)
@@ -564,7 +685,6 @@ server <- function(input, output, session) {
564
  baseR_test_time <- reactiveVal(NULL)
565
 
566
  observeEvent(input$run_randtest_btn, {
567
- # ------------------ COMPUTING RESULTS TOGGLE ------------------
568
  withProgress(message = "Computing results...", value = 0, {
569
 
570
  req(RerandResult())
@@ -599,7 +719,6 @@ server <- function(input, output, session) {
599
  fastrand_test_time(difftime(t1_testfast, t0_testfast, units = "secs"))
600
 
601
  # =========== 2) base R randomization test timing ===========
602
- # We must also have the base R set of randomizations
603
  req(RerandResult_base())
604
  rr_base <- RerandResult_base()
605
  if (is.null(rr_base$randomizations) || nrow(rr_base$randomizations) < 1) {
@@ -613,7 +732,8 @@ server <- function(input, output, session) {
613
  baseR_randomization_test(
614
  obsW = obsW,
615
  obsY = obsY,
616
- allW = rr_base$randomizations
 
617
  )
618
  }, error = function(e) e)
619
  t1_testbase <- Sys.time()
@@ -632,18 +752,18 @@ server <- function(input, output, session) {
632
  output$pvalue_box <- renderValueBox({
633
  rt <- RandTestResult()
634
  if (is.null(rt)) {
635
- valueBox("---", "p-value", icon = icon("question"), color = "blue")
636
  } else {
637
- valueBox(round(rt$p_value, 4), "p-value", icon = icon("list-check"), color = "purple")
638
  }
639
  })
640
 
641
  output$tauobs_box <- renderValueBox({
642
  rt <- RandTestResult()
643
  if (is.null(rt)) {
644
- valueBox("---", "Observed Effect", icon = icon("question"), color = "maroon")
645
  } else {
646
- valueBox(round(rt$tau_obs, 4), "Observed Effect", icon = icon("bullseye"), color = "maroon")
647
  }
648
  })
649
 
@@ -668,7 +788,7 @@ server <- function(input, output, session) {
668
  }
669
  })
670
 
671
- # If we have a fiducial interval, display it
672
  output$fi_text <- renderUI({
673
  rt <- RandTestResult()
674
  if (is.null(rt) || is.null(rt$FI)) {
@@ -678,29 +798,43 @@ server <- function(input, output, session) {
678
  fi_upper <- round(rt$FI[2], 4)
679
 
680
  tagList(
681
- strong("Fiducial Interval (95%):"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  p(sprintf("[%.4f, %.4f]", fi_lower, fi_upper))
683
  )
684
  })
685
 
686
  # A simple plot for the randomization distribution (for demonstration).
687
- # In this minimal example, we do not store the entire distribution in 'randomization_test',
688
  # so we simply show the observed effect as a point.
689
  output$test_plot <- renderPlot({
690
  rt <- RandTestResult()
691
  if (is.null(rt)) {
692
- # no test run yet
693
  plot.new()
694
  title("No test results yet.")
695
  return(NULL)
696
  }
697
- # Just display the observed effect
698
  obs_val <- rt$tau_obs
699
 
700
  ggplot(data.frame(x = obs_val, y = 0), aes(x, y)) +
701
  geom_point(size=4, color="red") +
702
  xlim(c(obs_val - abs(obs_val)*2 - 1, obs_val + abs(obs_val)*2 + 1)) +
703
- labs(title = "Observed Treatment Effect",
704
  x = "Effect Size", y = "") +
705
  theme_minimal(base_size = 14) +
706
  geom_vline(xintercept = 0, linetype="dashed", color="gray40")
 
27
  # ---------------------------------------------------------
28
  # HELPER FUNCTIONS (BASE R)
29
  # ---------------------------------------------------------
30
+
31
  # 1) Compute Hotelling's T^2 in base R
32
  baseR_hotellingT2 <- function(X, W) {
33
  # For a single assignment W:
 
53
  }
54
 
55
  # 2) Generate randomizations in base R, filtering by acceptance probability
56
+ # using T^2 and keep the best (lowest) fraction.
57
  baseR_generate_randomizations <- function(n_units, n_treated, X, accept_prob, random_type,
58
  max_draws, batch_size) {
59
 
60
  # For safety, check if exact enumerations will explode:
 
61
  if (random_type == "exact") {
62
  n_comb_total <- choose(n_units, n_treated)
63
  if (n_comb_total > 1e6) {
 
136
  list(randomizations = assignment_mat_accepted, balance = T2vals_accepted)
137
  }
138
 
139
+ # Helper: compute difference in means quickly
140
+ diff_in_means <- function(Y, W) {
141
+ mean(Y[W == 1]) - mean(Y[W == 0])
142
+ }
143
+
144
+ # Helper: for a given tau, relabel outcomes and compute the difference in means for a single permutation
145
+ compute_diff_at_tau_for_oneW <- function(Wprime, obsY, obsW, tau) {
146
+ # Y0_under_null = obsY - obsW * tau
147
+ Y0 <- obsY - obsW * tau
148
+ # Y1_under_null = Y0 + tau
149
+ # But in practice, for assignment Wprime, the observed outcome is:
150
+ # Y'(i) = Y0(i) if Wprime(i) = 0, or Y0(i) + tau if Wprime(i)=1
151
+ Yprime <- Y0
152
+ Yprime[Wprime == 1] <- Y0[Wprime == 1] + tau
153
+ diff_in_means(Yprime, Wprime)
154
+ }
155
+
156
+ # 3a) For base R randomization test: difference in means + optional p-value
157
+ # *without* fiducial interval
158
+ # (We will incorporate the FI logic below.)
159
+ baseR_randomization_test <- function(obsW, obsY, allW, findFI = FALSE, alpha = 0.05) {
160
+ # Observed diff in means
161
+ tau_obs <- diff_in_means(obsY, obsW)
162
 
163
  # for each candidate assignment, compute diff in means on obsY
164
+ diffs <- apply(allW, 1, function(w) diff_in_means(obsY, w))
 
 
165
 
166
  # p-value = fraction whose absolute diff >= observed
167
+ pval <- mean(abs(diffs) >= abs(tau_obs))
168
+
169
+ # optionally compute a fiducial interval
170
+ FI <- NULL
171
+ if (findFI) {
172
+ FI <- baseR_find_fiducial_interval(obsW, obsY, allW, tau_obs, alpha = alpha)
173
+ }
174
+
175
+ list(p_value = pval, tau_obs = tau_obs, FI = FI)
176
+ }
177
+
178
+ # 3b) The fiducial interval logic for base R, mirroring the approach in fastrerandomize:
179
+ # 1) Attempt to find a wide lower and upper bracket via random updates
180
+ # 2) Then a grid search in [lowerBound-1, upperBound*2] for which tau are accepted.
181
+ baseR_find_fiducial_interval <- function(obsW, obsY, allW, tau_obs, alpha = 0.05, c_initial = 2,
182
+ n_search_attempts = 500) {
183
+
184
+ # random bracket approach
185
+ lowerBound_est <- tau_obs - 3*tau_obs
186
+ upperBound_est <- tau_obs + 3*tau_obs
187
+
188
+ z_alpha <- qnorm(1 - alpha)
189
+ k <- 2 / (z_alpha * (2 * pi)^(-1/2) * exp(-z_alpha^2 / 2))
190
+
191
+ # For each iteration, pick one random assignment from allW
192
+ # then see how the implied difference changes, and update the bracket
193
+ n_allW <- nrow(allW)
194
+ for (step_t in seq_len(n_search_attempts)) {
195
+ # pick random assignment
196
+ idx <- sample.int(n_allW, 1)
197
+ Wprime <- allW[idx, ]
198
+
199
+ # ~~~~~ update lowerBound ~~~~~
200
+ # Y0 = obsY - obsW * lowerBound_est
201
+ # Y'(Wprime) = ...
202
+ lowerY0 <- obsY - obsW * lowerBound_est
203
+ Yprime_lower <- lowerY0
204
+ Yprime_lower[Wprime == 1] <- lowerY0[Wprime == 1] + lowerBound_est
205
+
206
+ tau_at_step_lower <- diff_in_means(Yprime_lower, Wprime)
207
+
208
+ c_step <- c_initial
209
+ # difference from obs
210
+ delta <- tau_obs - tau_at_step_lower
211
+
212
+ if (tau_at_step_lower < tau_obs) {
213
+ # move lowerBound up
214
+ lowerBound_est <- lowerBound_est + k * delta * (alpha/2) / step_t
215
+ } else {
216
+ # move it down
217
+ lowerBound_est <- lowerBound_est - k * (-delta) * (1 - alpha/2) / step_t
218
+ }
219
+
220
+ # ~~~~~ update upperBound ~~~~~
221
+ upperY0 <- obsY - obsW * upperBound_est
222
+ Yprime_upper <- upperY0
223
+ Yprime_upper[Wprime == 1] <- upperY0[Wprime == 1] + upperBound_est
224
+
225
+ tau_at_step_upper <- diff_in_means(Yprime_upper, Wprime)
226
+ delta2 <- tau_at_step_upper - tau_obs
227
+
228
+ if (tau_at_step_upper > tau_obs) {
229
+ # move upperBound down
230
+ upperBound_est <- upperBound_est - k * delta2 * (alpha/2) / step_t
231
+ } else {
232
+ # move it up
233
+ upperBound_est <- upperBound_est + k * (-delta2) * (1 - alpha/2) / step_t
234
+ }
235
+ }
236
+
237
+ # Now we do a grid search from (lowerBound_est - 1) to (upperBound_est * 2)
238
+ # in e.g. 100 steps, seeing which tau is "accepted".
239
+ # We'll define "accepted" if the min of:
240
+ # fraction(tau_obs >= distribution_of(tau_pseudo))
241
+ # fraction(tau_obs <= distribution_of(tau_pseudo))
242
+ # is > alpha, i.e. do not reject
243
+ grid_lower <- lowerBound_est - 1
244
+ grid_upper <- upperBound_est * 2
245
+ tau_seq <- seq(grid_lower, grid_upper, length.out = 100)
246
+
247
+ accepted <- logical(length(tau_seq))
248
+ for (i in seq_along(tau_seq)) {
249
+ tau_pseudo <- tau_seq[i]
250
+ # for each row in allW, compute the diff in means if the true effect = tau_pseudo
251
+ # distribution_of(tau_pseudo)
252
+ diffs_pseudo <- apply(allW, 1, function(wp) compute_diff_at_tau_for_oneW(wp, obsY, obsW, tau_pseudo))
253
+ # Then see how often diffs_pseudo >= tau_obs (or <= tau_obs)
254
+ frac_ge <- mean(diffs_pseudo >= tau_obs)
255
+ frac_le <- mean(diffs_pseudo <= tau_obs)
256
+ # min(...) is the typical "two-sided" approach
257
+ accepted[i] <- (min(frac_ge, frac_le) > alpha / 2) # or 0.05 if we want 5% test
258
+ }
259
+
260
+ if (!any(accepted)) {
261
+ # no values accepted => degenerate?
262
+ # We'll return the bracket we found, or NA.
263
+ return(c(NA, NA))
264
+ }
265
+
266
+ c(min(tau_seq[accepted]), max(tau_seq[accepted]))
267
  }
268
 
269
  # ---------------------------------------------------------
 
423
 
424
  box(width = 8, title = "Test Results", status = "info", solidHeader = TRUE,
425
 
426
+ # First row: p-value and observed effect (fastrerandomize)
427
  fluidRow(
428
  column(width = 6, valueBoxOutput("pvalue_box", width = 12)),
429
  column(width = 6, valueBoxOutput("tauobs_box", width = 12))
 
435
  column(width = 6, valueBoxOutput("baseR_test_time_box", width = 12))
436
  ),
437
 
438
+ # Show fastrerandomize FI
439
  uiOutput("fi_text"),
440
+
441
+ # Now show Base R results in a separate row
442
+ tags$hr(),
443
+ fluidRow(
444
+ column(width = 6, valueBoxOutput("pvalue_box_baseR", width = 12)),
445
+ column(width = 6, valueBoxOutput("tauobs_box_baseR", width = 12))
446
+ ),
447
+ fluidRow(
448
+ column(width = 12, uiOutput("fi_text_baseR"))
449
+ ),
450
+
451
  br(),
452
  plotOutput("test_plot", height = "280px")
453
  )
 
515
  "Number treated cannot exceed total units.")
516
  )
517
 
 
518
  withProgress(message = "Computing results...", value = 0, {
519
 
520
  # =========== 1) fastrerandomize generation timing ===========
 
624
  # Hardware info (CPU cores, GPU note)
625
  output$hardware_info <- renderUI({
626
  num_cores <- detectCores(logical = TRUE)
 
627
  HTML(paste(
628
  "<strong>System Hardware Info:</strong><br/>",
629
  "Number of CPU cores detected:", num_cores, "<br/>",
 
640
  observeEvent(input$simulateY_btn, {
641
  req(RerandResult())
642
  rr <- RerandResult()
 
 
643
  if (is.null(rr$randomizations) || nrow(rr$randomizations) < 1) {
644
  showNotification("No accepted randomizations found. Cannot simulate Y for the 'observed' assignment.", type = "error")
645
  return(NULL)
 
685
  baseR_test_time <- reactiveVal(NULL)
686
 
687
  observeEvent(input$run_randtest_btn, {
 
688
  withProgress(message = "Computing results...", value = 0, {
689
 
690
  req(RerandResult())
 
719
  fastrand_test_time(difftime(t1_testfast, t0_testfast, units = "secs"))
720
 
721
  # =========== 2) base R randomization test timing ===========
 
722
  req(RerandResult_base())
723
  rr_base <- RerandResult_base()
724
  if (is.null(rr_base$randomizations) || nrow(rr_base$randomizations) < 1) {
 
732
  baseR_randomization_test(
733
  obsW = obsW,
734
  obsY = obsY,
735
+ allW = rr_base$randomizations,
736
+ findFI = input$findFI # if user wants the FI, do so
737
  )
738
  }, error = function(e) e)
739
  t1_testbase <- Sys.time()
 
752
  output$pvalue_box <- renderValueBox({
753
  rt <- RandTestResult()
754
  if (is.null(rt)) {
755
+ valueBox("---", "p-value (fastrerandomize)", icon = icon("question"), color = "blue")
756
  } else {
757
+ valueBox(round(rt$p_value, 4), "p-value (fastrerandomize)", icon = icon("list-check"), color = "purple")
758
  }
759
  })
760
 
761
  output$tauobs_box <- renderValueBox({
762
  rt <- RandTestResult()
763
  if (is.null(rt)) {
764
+ valueBox("---", "Observed Effect (fastrerandomize)", icon = icon("question"), color = "maroon")
765
  } else {
766
+ valueBox(round(rt$tau_obs, 4), "Observed Effect (fastrerandomize)", icon = icon("bullseye"), color = "maroon")
767
  }
768
  })
769
 
 
788
  }
789
  })
790
 
791
+ # If we have a fiducial interval from fastrerandomize, display it
792
  output$fi_text <- renderUI({
793
  rt <- RandTestResult()
794
  if (is.null(rt) || is.null(rt$FI)) {
 
798
  fi_upper <- round(rt$FI[2], 4)
799
 
800
  tagList(
801
+ strong("Fiducial Interval (fastrerandomize, 95%):"),
802
+ p(sprintf("[%.4f, %.4f]", fi_lower, fi_upper))
803
+ )
804
+ })
805
+
806
+ # If we have a fiducial interval from base R, display it
807
+ output$fi_text_baseR <- renderUI({
808
+ rt <- RandTestResult_base()
809
+ if (is.null(rt) || is.null(rt$FI)) {
810
+ return(NULL)
811
+ }
812
+ fi_lower <- round(rt$FI[1], 4)
813
+ fi_upper <- round(rt$FI[2], 4)
814
+
815
+ tagList(
816
+ strong("Fiducial Interval (base R, 95%):"),
817
  p(sprintf("[%.4f, %.4f]", fi_lower, fi_upper))
818
  )
819
  })
820
 
821
  # A simple plot for the randomization distribution (for demonstration).
822
+ # In this app, we do not store the entire distribution from either method,
823
  # so we simply show the observed effect as a point.
824
  output$test_plot <- renderPlot({
825
  rt <- RandTestResult()
826
  if (is.null(rt)) {
 
827
  plot.new()
828
  title("No test results yet.")
829
  return(NULL)
830
  }
831
+ # Just display the observed effect from fastrerandomize
832
  obs_val <- rt$tau_obs
833
 
834
  ggplot(data.frame(x = obs_val, y = 0), aes(x, y)) +
835
  geom_point(size=4, color="red") +
836
  xlim(c(obs_val - abs(obs_val)*2 - 1, obs_val + abs(obs_val)*2 + 1)) +
837
+ labs(title = "Observed Treatment Effect (fastrerandomize)",
838
  x = "Effect Size", y = "") +
839
  theme_minimal(base_size = 14) +
840
  geom_vline(xintercept = 0, linetype="dashed", color="gray40")