## -----------------------------------------------------------------------------
library(dplyr)
library(ggplot2)
library(ggalluvial)
library(wompwomp)
set.seed(43)

## -----------------------------------------------------------------------------
# Toy input: one row per observation, recording the tissue it came from and
# the cluster it was assigned to. Both columns have 27 entries; the cluster
# values are listed in the same tissue order as `tissue` below.
df <- data.frame(
  tissue = rep(
    c("BRAIN", "STOMACH", "HEART", "T CELL", "B CELL"),
    times = c(3, 6, 7, 2, 9)
  ),
  cluster = c(
    1, 1, 2,                      # BRAIN
    1, 2, 2, 2, 2, 2,             # STOMACH
    1, 3, 3, 3, 3, 3, 3,          # HEART
    4, 4,                         # T CELL
    4, 4, 4, 4, 4, 4, 4, 4, 4     # B CELL
  )
)

# preprocess (manual): count observations per (tissue, cluster) pair, giving
# the weighted edge table used by the plots below (unsorted + uncolored here).
# `.groups = "drop"` returns an ungrouped tibble directly, so no separate
# ungroup() is needed and summarise() does not emit its regrouping message.
clus_df_gather <- df |>
  dplyr::group_by(tissue, cluster) |>
  dplyr::summarize(value = dplyr::n(), .groups = "drop") |>
  # axis variables are categorical labels, so store them as character
  dplyr::mutate(dplyr::across(c(tissue, cluster), as.character))
print(clus_df_gather)

## -----------------------------------------------------------------------------
# unsorted plot: alluvial diagram of the raw (unsorted) tissue -> cluster table
ggplot(
  data = clus_df_gather,
  mapping = aes(y = value, axis1 = tissue, axis2 = cluster)
) +
  # flows (alluvia) are filled by the first axis, tissue
  geom_alluvium(mapping = aes(fill = tissue), width = 1 / 12) +
  # strata (boxes on each axis) are filled by their own label
  geom_stratum(
    mapping = aes(fill = after_stat(stratum)),
    width = 1 / 12,
    color = "grey"
  ) +
  geom_label(mapping = aes(label = after_stat(stratum)), stat = "stratum") +
  scale_x_discrete(limits = c("tissue", "cluster"), expand = c(0.05, 0.05))

## -----------------------------------------------------------------------------
# sort (tidy) and plot - sorted + uncolored
# Reorder the weighted table with wompwomp's sort_to_uncross() (weighted
# metric enabled), keep the result, and show it.
clus_df_gather_sort <- sort_to_uncross(
  clus_df_gather,
  cols = c(tissue, cluster),
  wt = value,
  options = list(weighted_metric = TRUE)
)
print(clus_df_gather_sort)

ggplot(
  data = clus_df_gather_sort,
  mapping = aes(y = value, axis1 = tissue, axis2 = cluster)
) +
  # flows (alluvia) are filled by the first axis, tissue
  geom_alluvium(mapping = aes(fill = tissue), width = 1 / 12) +
  # strata (boxes on each axis) are filled by their own label
  geom_stratum(
    mapping = aes(fill = after_stat(stratum)),
    width = 1 / 12,
    color = "grey"
  ) +
  geom_label(mapping = aes(label = after_stat(stratum)), stat = "stratum") +
  scale_x_discrete(limits = c("tissue", "cluster"), expand = c(0.05, 0.05))

## -----------------------------------------------------------------------------
# color (tidy) and plot - sorted + colored
# get_lode_clusters() derives a mapping from the sorted table (presumably
# grouping strata/lodes into color clusters -- confirm against wompwomp docs);
# that mapping is then turned into the fill palette of the plot below.
cluster_mapping <- get_lode_clusters(
  clus_df_gather_sort,
  cols = c(tissue, cluster),
  wt = value
)
print(cluster_mapping)

ggplot(
  data = clus_df_gather_sort,
  mapping = aes(y = value, axis1 = tissue, axis2 = cluster)
) +
  # flows (alluvia) are filled by the first axis, tissue
  geom_alluvium(mapping = aes(fill = tissue), width = 1 / 12) +
  # strata (boxes on each axis) are filled by their own label
  geom_stratum(
    mapping = aes(fill = after_stat(stratum)),
    width = 1 / 12,
    color = "grey"
  ) +
  geom_label(mapping = aes(label = after_stat(stratum)), stat = "stratum") +
  scale_x_discrete(limits = c("tissue", "cluster"), expand = c(0.05, 0.05)) +
  # one manual fill scale, shared by the alluvium and stratum layers
  scale_fill_manual(
    values = lode_cluster_pal(
      data = clus_df_gather_sort,
      cols = c(tissue, cluster),
      mapping = cluster_mapping
    )
  )

## -----------------------------------------------------------------------------
# Evaluate the sorted layout with wompwomp's crossing objective and report it
# (presumably a weighted count of edge crossings -- confirm in wompwomp docs).
crossing_edges_out <- compute_crossing_objective(
  clus_df_gather_sort,
  cols = c("tissue", "cluster"),
  wt = "value"
)
print(crossing_edges_out$output_objective)

## -----------------------------------------------------------------------------
# Record R and package versions alongside the output for reproducibility.
utils::sessionInfo()

