## ----include = FALSE----------------------------------------------------------
# Every chunk in this vignette is displayed but never executed (eval = FALSE).
# The default computation backend needs a Python runtime providing
# u-stats / numpy / torch, and the machines that build this vignette neither
# have that runtime nor should download it. The remaining options only shape
# how the (unevaluated) chunks are rendered: merged source/output blocks,
# "#>" output prefixes, and 7 x 5 inch figures.
do.call(
  knitr::opts_chunk$set,
  list(
    collapse = TRUE,
    comment = "#>",
    fig.width = 7,
    fig.height = 5,
    eval = FALSE
  )
)

## -----------------------------------------------------------------------------
# # From CRAN (once accepted):
# # install.packages("HOIF")
# 
# # Development version from GitHub (also installs the ustats R package):
# # install.packages("devtools")
# devtools::install_github("cxy0714/HOIF")

## -----------------------------------------------------------------------------
# library(HOIF)
# # First call provisions the Python environment automatically:
# # results <- hoif_ate(...)

## -----------------------------------------------------------------------------
# ustats::setup_ustats()             # CPU-only PyTorch (small download)
# ustats::setup_ustats(gpu = TRUE)   # default PyPI PyTorch (CUDA on Linux)
# ustats::check_ustats_setup()       # verify the environment

## -----------------------------------------------------------------------------
# reticulate::use_condaenv("your_env_name", required = TRUE)
# # or set the RETICULATE_PYTHON environment variable

## -----------------------------------------------------------------------------
# library(HOIF)
# 
# set.seed(123)
# n <- 2000
# p <- 10
# 
# # Covariates (all of them are confounders)
# X <- matrix(rnorm(n * p), ncol = p)
# 
# # True propensity score (logistic-linear) and outcome regressions
# # (linear), each loading on ALL covariates
# beta_pi <- c(0.3, -0.2, 0.2, rep(0.25, 7))
# beta1   <- c(0.5,  0.4, 0.3, rep(0.4, 7))
# beta0   <- c(0.3,  0.2, 0.1, rep(0.3, 7))
# 
# true_pi <- plogis(as.vector(X %*% beta_pi))
# A <- rbinom(n, 1, true_pi)
# 
# mu1_true <- as.vector(1 + X %*% beta1)
# mu0_true <- as.vector(X %*% beta0)
# Y <- A * mu1_true + (1 - A) * mu0_true + rnorm(n, 0, 0.2)
# 
# # True targets
# psi1_true <- mean(mu1_true)
# psi0_true <- mean(mu0_true)
# true_ate  <- psi1_true - psi0_true
# cat("True E[Y(1)]:", round(psi1_true, 4), "\n")
# #> True E[Y(1)]: 0.9799
# cat("True E[Y(0)]:", round(psi0_true, 4), "\n")
# #> True E[Y(0)]: -0.016
# cat("True ATE:    ", round(true_ate, 4), "\n")
# #> True ATE:     0.9958

## -----------------------------------------------------------------------------
# idx_nuis <- sample(n, n / 2)
# idx_est  <- setdiff(seq_len(n), idx_nuis)
# 
# X_nuis <- X[idx_nuis, , drop = FALSE]
# A_nuis <- A[idx_nuis]
# Y_nuis <- Y[idx_nuis]
# 
# X_est <- X[idx_est, , drop = FALSE]
# A_est <- A[idx_est]
# Y_est <- Y[idx_est]

## -----------------------------------------------------------------------------
# S <- 1:3  # the working models ignore covariates 4..10
# 
# # Outcome regressions, fitted on the nuisance sample only
# fit_mu1 <- lm(Y_nuis ~ X_nuis[, S], subset = A_nuis == 1)
# fit_mu0 <- lm(Y_nuis ~ X_nuis[, S], subset = A_nuis == 0)
# 
# # Propensity score, fitted on the nuisance sample only
# fit_pi <- glm(A_nuis ~ X_nuis[, S], family = binomial)
# 
# # Predict all nuisance functions on the estimation sample
# mu1_hat <- as.vector(cbind(1, X_est[, S]) %*% coef(fit_mu1))
# mu0_hat <- as.vector(cbind(1, X_est[, S]) %*% coef(fit_mu0))
# pi_hat  <- as.vector(plogis(cbind(1, X_est[, S]) %*% coef(fit_pi)))
# 
# # Ensure propensity scores are bounded away from 0 and 1
# pi_hat <- pmax(pmin(pi_hat, 0.95), 0.05)

## -----------------------------------------------------------------------------
# psi1_aipw <- mean(mu1_hat + A_est / pi_hat * (Y_est - mu1_hat))
# psi0_aipw <- mean(mu0_hat + (1 - A_est) / (1 - pi_hat) * (Y_est - mu0_hat))
# ate_aipw  <- psi1_aipw - psi0_aipw
# 
# cat("AIPW estimates: E[Y(1)] =", round(psi1_aipw, 4),
#     "  E[Y(0)] =", round(psi0_aipw, 4),
#     "  ATE =", round(ate_aipw, 4), "\n")
# #> AIPW estimates: E[Y(1)] = 1.1749   E[Y(0)] = -0.2676   ATE = 1.4425
# cat("AIPW errors:    E[Y(1)] =", round(psi1_aipw - psi1_true, 4),
#     "  E[Y(0)] =", round(psi0_aipw - psi0_true, 4),
#     "  ATE =", round(ate_aipw - true_ate, 4), "\n")
# #> AIPW errors:    E[Y(1)] = 0.195   E[Y(0)] = -0.2516   ATE = 0.4466

## -----------------------------------------------------------------------------
# results_ehoif <- hoif_ate(
#   X = X_est,
#   A = A_est,
#   Y = Y_est,
#   mu1 = mu1_hat,
#   mu0 = mu0_hat,
#   pi = pi_hat,
#   transform_method = "none",  # Use raw covariates
#   inverse_method = "direct",
#   m = 7,                      # Compute up to 7th order
#   sample_split = TRUE,
#   n_folds = 2,
#   seed = 42,
#   backend = "torch"           # Use Python backend
# )
# 
# print(results_ehoif)
# #> HOIF Estimators for Average Treatment Effect
# #> =============================================
# #>
# #> Higher-order correction terms by order:
# #>   Order     ATE   HOIF1  HOIF0
# #> 1     2 -0.4621 -0.2615 0.2006
# #> 2     3 -0.4050 -0.2341 0.1709
# #> 3     4 -0.4335 -0.2480 0.1855
# #> 4     5 -0.4272 -0.2443 0.1829
# #> 5     6 -0.4289 -0.2456 0.1833
# #> 6     7 -0.4293 -0.2455 0.1838
# #>
# #> Estimated AIPW bias correction for the ATE (highest order): -0.4293
# #> (add this value to the first-order AIPW/DR estimate of the ATE to debias it)

## -----------------------------------------------------------------------------
# results_shoif <- hoif_ate(
#   X = X_est,
#   A = A_est,
#   Y = Y_est,
#   mu1 = mu1_hat,
#   mu0 = mu0_hat,
#   pi = pi_hat,
#   transform_method = "none",
#   inverse_method = "direct",
#   m = 7,
#   sample_split = FALSE,
#   backend = "torch"
# )
# 
# print(results_shoif)
# #> HOIF Estimators for Average Treatment Effect
# #> =============================================
# #>
# #> Higher-order correction terms by order:
# #>   Order     ATE   HOIF1  HOIF0
# #> 1     2 -0.4318 -0.2496 0.1821
# #> 2     3 -0.4458 -0.2568 0.1890
# #> 3     4 -0.4373 -0.2520 0.1853
# #> 4     5 -0.4362 -0.2515 0.1848
# #> 5     6 -0.4366 -0.2517 0.1849
# #> 6     7 -0.4367 -0.2517 0.1849
# #>
# #> Estimated AIPW bias correction for the ATE (highest order): -0.4367
# #> (add this value to the first-order AIPW/DR estimate of the ATE to debias it)

## -----------------------------------------------------------------------------
# psi1_ehoif <- psi1_aipw + tail(results_ehoif$HOIF1, 1)
# psi0_ehoif <- psi0_aipw + tail(results_ehoif$HOIF0, 1)
# psi1_shoif <- psi1_aipw + tail(results_shoif$HOIF1, 1)
# psi0_shoif <- psi0_aipw + tail(results_shoif$HOIF0, 1)
# 
# comparison <- data.frame(
#   row.names = c("E[Y(1)]", "E[Y(0)]", "ATE"),
#   Truth = c(psi1_true, psi0_true, true_ate),
#   AIPW  = c(psi1_aipw, psi0_aipw, ate_aipw),
#   eHOIF = c(psi1_ehoif, psi0_ehoif, psi1_ehoif - psi0_ehoif),
#   sHOIF = c(psi1_shoif, psi0_shoif, psi1_shoif - psi0_shoif)
# )
# round(comparison, 4)
# #>           Truth    AIPW   eHOIF   sHOIF
# #> E[Y(1)]  0.9799  1.1749  0.9294  0.9232
# #> E[Y(0)] -0.0160 -0.2676 -0.0838 -0.0826
# #> ATE      0.9958  1.4425  1.0132  1.0058
# 
# # Errors relative to the truth
# round(sweep(comparison[, -1], 1, comparison$Truth, "-"), 4)
# #>            AIPW   eHOIF   sHOIF
# #> E[Y(1)]  0.1950 -0.0505 -0.0567
# #> E[Y(0)] -0.2516 -0.0678 -0.0666
# #> ATE      0.4466  0.0174  0.0100

## -----------------------------------------------------------------------------
# plot(results_ehoif)

## -----------------------------------------------------------------------------
# # B-splines basis
# results_splines <- hoif_ate(
#   X = X_est, A = A_est, Y = Y_est,
#   mu1 = mu1_hat, mu0 = mu0_hat, pi = pi_hat,
#   transform_method = "splines",
#   basis_dim = 5,        # 5 basis functions per covariate
#   degree = 3,           # Cubic splines
#   m = 5,
#   sample_split = FALSE
# )
# 
# # Fourier basis
# results_fourier <- hoif_ate(
#   X = X_est, A = A_est, Y = Y_est,
#   mu1 = mu1_hat, mu0 = mu0_hat, pi = pi_hat,
#   transform_method = "fourier",
#   basis_dim = 4,        # 4 Fourier components per covariate
#   period = 1,
#   m = 5,
#   sample_split = FALSE
# )

## -----------------------------------------------------------------------------
# results_cf <- hoif_ate(
#   X = X_est, A = A_est, Y = Y_est,
#   mu1 = mu1_hat, mu0 = mu0_hat, pi = pi_hat,
#   transform_method = "none",
#   m = 5,
#   sample_split = TRUE,
#   n_folds = 5,          # 5-fold cross-fitting
#   seed = 42
# )

## -----------------------------------------------------------------------------
# # Nonlinear shrinkage
# results_nlshrink <- hoif_ate(
#   X = X_est, A = A_est, Y = Y_est,
#   mu1 = mu1_hat, mu0 = mu0_hat, pi = pi_hat,
#   inverse_method = "nlshrink",
#   m = 5
# )
# 
# # corpcor shrinkage
# results_corpcor <- hoif_ate(
#   X = X_est, A = A_est, Y = Y_est,
#   mu1 = mu1_hat, mu0 = mu0_hat, pi = pi_hat,
#   inverse_method = "corpcor",
#   m = 5
# )

## -----------------------------------------------------------------------------
# results_pure_r <- hoif_ate(
#   X = X_est, A = A_est, Y = Y_est,
#   mu1 = mu1_hat, mu0 = mu0_hat, pi = pi_hat,
#   pure_R_code = TRUE,
#   m = 6  # Maximum 6 for pure R
# )

## -----------------------------------------------------------------------------
# ustats::check_ustats_setup()

## -----------------------------------------------------------------------------
# results <- hoif_ate(
#   X = X_est, A = A_est, Y = Y_est,
#   mu1 = mu1_hat, mu0 = mu0_hat, pi = pi_hat,
#   pure_R_code = TRUE,
#   m = 6  # Limited to order 6
# )

## -----------------------------------------------------------------------------
# # Try shrinkage methods
# results <- hoif_ate(
#   X = X_est, A = A_est, Y = Y_est,
#   mu1 = mu1_hat, mu0 = mu0_hat, pi = pi_hat,
#   inverse_method = "nlshrink"  # or "corpcor"
# )

## -----------------------------------------------------------------------------
# # Trim to [0.05, 0.95]
# pi_hat <- pmax(pmin(pi_hat, 0.95), 0.05)

