# ============================================================
# COMPARISON.R
# Overlays 2-light (Analysis_PILOT.R) and 3-light (Analysis_3L_PILOT.R)
# analyses in the same graphs — Sections 2–8.
#
# FILTER: correlationhigh == 1 (Easy LA) is DROPPED from 2L data.
#
# CONDITION → COLOUR:
#   2L Easy (HA)   [2L correlationhigh == 2]   green       #4dac26
#   3L High        [3L correlationhigh == 1]   red         #d6604d
#   2L Difficult   [2L correlationhigh == 0]   dark blue   #2166ac
#   3L Low         [3L correlationhigh == 0]   light blue  #92c5de
#
# By-treatment facets use only treatments common to both datasets:
#   AND, OR, INHIBIT, EITHER, JOINT.
# max_correct is normalised to frac_correct = max_correct / n_guesses
# (12 for 2L, 16 for 3L) so both datasets share the same [0,1] x-axis.
# ============================================================

library(tidyverse)

# Directory of this script, resolved through the RStudio API.
script_dir <- dirname(rstudioapi::getActiveDocumentContext()$path)

# 1. SOURCE STRUCTURAL SCRIPTS ------

# Helper: structural scripts use rstudioapi::getActiveDocumentContext()$path to
# locate their own directory. When sourced from COMPARISON.R that resolves to
# *this* file's directory, which is wrong for scripts in sub-folders.
# source_fixed() patches the script_dir assignment before eval-ing the code.
#
# Arguments:
#   file_path  path of the script to run
#   env        environment in which the patched script is evaluated
# Returns `env`, invisibly.
source_fixed <- function(file_path, env) {
  own_dir <- normalizePath(dirname(file_path), winslash = "/")
  script_lines <- readLines(file_path, warn = FALSE)

  # Patch 1: point the rstudioapi-derived script_dir assignment at the
  # script's own directory instead of COMPARISON.R's directory.
  script_dir_pattern <-
    "script_dir <- dirname\\(rstudioapi::getActiveDocumentContext\\(\\)\\$path\\)"
  script_lines <- sub(
    pattern = script_dir_pattern,
    replacement = sprintf('script_dir <- "%s"', own_dir),
    x = script_lines,
    fixed = FALSE
  )

  # Patch 2: nested source(file.path(script_dir, "*.R")) calls inside the
  # script would hit the same rstudioapi problem, so they are rewritten to
  # source_fixed(..., environment()) and the fix propagates to DOfiles too.
  nested_source_pattern <-
    'source\\(file\\.path\\(script_dir,\\s*"([^"]+\\.R)"\\)\\)'
  script_lines <- gsub(
    pattern = nested_source_pattern,
    replacement = 'source_fixed(file.path(script_dir, "\\1"), environment())',
    x = script_lines,
    fixed = FALSE
  )

  # Run the patched text (not the file on disk) inside the target environment.
  patched_code <- paste(script_lines, collapse = "\n")
  eval(parse(text = patched_code), envir = env)
  invisible(env)
}

# 2L scripts live in the same directory as COMPARISON.R — plain source() works.
# Each session gets its own environment so their objects do not collide.
env_2L1 <- new.env(parent = globalenv())
source(
  file.path(script_dir, "Structural_modelselection_2lights1.R"),
  local = env_2L1
)

env_2L2 <- new.env(parent = globalenv())
source(
  file.path(script_dir, "Structural_modelselection_2lights.R"),
  local = env_2L2
)

env_2L3 <- new.env(parent = globalenv())
source(
  file.path(script_dir, "Structural_modelselection_2lights3.R"),
  local = env_2L3
)

env_2L4 <- new.env(parent = globalenv())
source(
  file.path(script_dir, "Structural_modelselection_2lights4.R"),
  local = env_2L4
)

# 3L scripts live in "3 Lights/" — use source_fixed() so their internal
# source() calls resolve relative to that sub-directory.
dir_3L <- file.path(script_dir, "3 Lights")

# env_3L1 <- new.env(parent = globalenv())
# source_fixed(file.path(dir_3L, "Structural_3lights.R"), env_3L1)
#
# env_3L2 <- new.env(parent = globalenv())
# source_fixed(file.path(dir_3L, "Structural_3lights (2).R"), env_3L2)

env_3L1 <- new.env(parent = globalenv())
source_fixed(file.path(dir_3L, "Structural_modelselection.R"), env_3L1)

env_3L2 <- new.env(parent = globalenv())
source_fixed(file.path(dir_3L, "Structural_modelselection(2).R"), env_3L2)

# Close the current graphics device, if any; an error from dev.off() is
# deliberately ignored.
tryCatch(grDevices::dev.off(), error = function(e) invisible(NULL))
# 2. BUILD Data_2L (mirrors Analysis_PILOT.R sections 1–5) ----

# Tag each 2L session's post-estimation data with a session label.
df_2L1 <- mutate(env_2L1$df_final_post, session = "2L(1)")
df_2L2 <- mutate(env_2L2$df_final_post, session = "2L(2)")
df_2L3 <- mutate(env_2L3$df_final_post, session = "2L(3)")
df_2L4 <- mutate(env_2L4$df_final_post, session = "2L(4)")

# Number of distinct subjects in each of the first three sessions.
n_2L1 <- n_distinct(df_2L1$subject_id)
n_2L2 <- n_distinct(df_2L2$subject_id)
n_2L3 <- n_distinct(df_2L3$subject_id)

# Shift subject ids of later sessions by the subject counts of the sessions
# before them.
# NOTE(review): this only yields unique ids if each session numbers its
# subjects 1..n — confirm against the structural scripts.
df_2L2 <- mutate(df_2L2, subject_id = subject_id + n_2L1)
df_2L3 <- mutate(df_2L3, subject_id = subject_id + n_2L1 + n_2L2)
df_2L4 <- mutate(df_2L4, subject_id = subject_id + n_2L1 + n_2L2 + n_2L3)

Data_2L <- bind_rows(df_2L1, df_2L2, df_2L3, df_2L4) %>%
  # Drop the AND treatment from sessions 2L(2) and 2L(4).
  filter(treatment != "AND" | !(session %in% c("2L(2)", "2L(4)"))) %>%
  mutate(
    session = factor(session, levels = c("2L(1)", "2L(2)", "2L(3)", "2L(4)"))
  ) %>%
  arrange(session, subject_id, round_order, Guess_Number) %>%
  mutate(
    # Condition code per treatment × session; per the file header,
    # 0 = Difficult, 1 = Easy (LA), 2 = Easy (HA). First match wins.
    correlationhigh = case_when(
      treatment == "ALONE_difficult" & session == "2L(3)" ~ 0L,
      session == "2L(3)" ~ 2L,
      # Cannot match: AND rows of 2L(2) were removed by the filter above.
      treatment == "AND" & session == "2L(2)" ~ 1L,
      treatment == "AND" & session == "2L(1)" ~ 0L,
      treatment == "EITHER" & session == "2L(2)" ~ 1L,
      treatment == "OR" & session == "2L(1)" ~ 1L,
      treatment == "INHIBIT" & session == "2L(1)" ~ 2L,
      treatment == "JOINT" & session == "2L(1)" ~ 2L,
      # Cannot match: AND rows of 2L(4) were removed by the filter above.
      treatment == "AND" & session == "2L(4)" ~ 0L,
      treatment == "EITHER" & session == "2L(4)" ~ 2L,
      session == "2L(4)" ~ 1L,
      TRUE ~ 0L
    )
  )
# 3. BUILD Data_3L (mirrors Analysis_3L_PILOT.R sections 1–5) ----

df_3L1 <- env_3L1$df_final_post %>% mutate(session = "3L(1)")
df_3L2 <- env_3L2$df_final_post %>% mutate(session = "3L(2)")

# Shift session-2 subject ids by the number of session-1 subjects.
n_3L1 <- n_distinct(df_3L1$subject_id)
df_3L2 <- df_3L2 %>% mutate(subject_id = subject_id + n_3L1)

Data_3L <- bind_rows(df_3L1, df_3L2) %>%
  mutate(
    session = factor(session, levels = c("3L(1)", "3L(2)")),
    # 1 = High, 0 = Low; which treatments are High depends on the session.
    correlationhigh = as.integer(
      (session == "3L(1)" & treatment %in% c("AND", "INHIBIT")) |
        (session == "3L(2)" & treatment %in% c("JOINT", "OR", "EITHER"))
    )
  ) %>%
  arrange(session, subject_id, round_order, Guess_Number)

# 4. COMPARISON SETUP ------

common_treatments <- c("AND", "OR", "INHIBIT", "EITHER", "JOINT")

# Condition keys, in legend / plotting order.
cmp_levels <- c("2L_Easy_HA", "3L_High", "2L_Difficult", "3L_Low")

cmp_labels <- c(
  "2L_Easy_HA"   = "2L Easy (HA)",
  "3L_High"      = "3L High",
  "2L_Difficult" = "2L Difficult",
  "3L_Low"       = "3L Low"
)

cmp_colours <- c(
  "2L_Easy_HA"   = "#4dac26",
  "3L_High"      = "#d6604d",
  "2L_Difficult" = "#2166ac",
  "3L_Low"       = "#92c5de"
)
cmp_fills <- cmp_colours

# ---- 2L: drop Easy (LA), add condition
Data_2L_cmp <- Data_2L %>%
  filter(correlationhigh != 1L) %>%
  mutate(
    condition = factor(
      if_else(correlationhigh == 2L, "2L_Easy_HA", "2L_Difficult"),
      levels = cmp_levels
    ),
    n_guesses = 12L,
    source = "2L",
    Rule_used_std = Rule_used,
    # ALONE_* treatments have no 3L counterpart; set to NA so they are
    # excluded from by-treatment comparisons automatically.
    treatment_cmp = if_else(
      treatment %in% c("ALONE_easy", "ALONE_difficult"),
      NA_character_,
      as.character(treatment)
    ),
    frac_correct = max_correct / n_guesses
  )

# ---- 3L: add condition, strip "_rb" suffix from rule names
Data_3L_cmp <- Data_3L %>%
  mutate(
    condition = factor(
      if_else(correlationhigh == 1L, "3L_High", "3L_Low"),
      levels = cmp_levels
    ),
    n_guesses = 16L,
    source = "3L",
    Rule_used_std = str_remove(Rule_used, "_rb$"),
    treatment_cmp = as.character(treatment),
    frac_correct = max_correct / n_guesses
  )

# ---- Subject × treatment summary (common columns only)
COMMON_COLS <- c(
  "subject_id", "source", "session", "treatment", "treatment_cmp",
  "condition", "n_guesses", "max_correct", "frac_correct",
  "trueextracted", "Rule_used", "Rule_used_std", "posterior",
  "time_machine", "predicted_correct_self"
)

ss_2L <- Data_2L_cmp %>% distinct(across(all_of(COMMON_COLS)))
ss_3L <- Data_3L_cmp %>% distinct(across(all_of(COMMON_COLS)))
ss_cmp <- bind_rows(ss_2L, ss_3L)

# SECTION 4: BAR PLOTS — fraction correct--------
cat("\n===== SECTION 4: BAR PLOTS — FRACTION CORRECT =====\n")

# Bin frac_correct into ten equal-width bins [0, 0.1), ..., [0.9, 1.0];
# fc_mid is the midpoint of the bin and is used as the bar position.
bin_breaks <- seq(0, 1, by = 0.1)

ss_binned <- ss_cmp %>%
  mutate(
    fc_bin = cut(frac_correct, breaks = bin_breaks,
                 include.lowest = TRUE, right = FALSE),
    fc_mid = bin_breaks[as.integer(fc_bin)] + 0.05
  ) %>%
  filter(!is.na(fc_bin))

# Aggregate
bar4_agg <- ss_binned %>%
  group_by(condition, fc_mid) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(condition) %>%
  mutate(frac = n / sum(n)) %>%
  ungroup()

means4_agg <- ss_cmp %>%
  group_by(condition) %>%
  summarise(mean_fc = mean(frac_correct, na.rm = TRUE), .groups = "drop")

p_bar4_agg <- ggplot(bar4_agg, aes(x = fc_mid, y = frac, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", width = 0.08, alpha = 0.85) +
  geom_vline(data = means4_agg,
             aes(xintercept = mean_fc, colour = condition),
             linewidth = 0.9, linetype = "dashed") +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  scale_colour_manual(values = cmp_colours, labels = cmp_labels,
                      guide = "none") +
  scale_x_continuous(breaks = bin_breaks) +
  labs(
    x = "Fraction correct (max_correct / n_guesses)",
    y = "Fraction of subjects",
    fill = NULL,
    title = "Distribution of accuracy — aggregate"
  ) +
  theme_bw() +
  theme(legend.position = "bottom",
        axis.text.x = element_text(angle = 45, hjust = 1, size = 7))
print(p_bar4_agg)

# By treatment
bar4_treat <- ss_binned %>%
  filter(!is.na(treatment_cmp), treatment_cmp %in% common_treatments) %>%
  mutate(treatment_cmp = factor(treatment_cmp, levels = common_treatments)) %>%
  group_by(treatment_cmp, condition, fc_mid) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(treatment_cmp, condition) %>%
  mutate(frac = n / sum(n)) %>%
  ungroup()

means4_treat <- ss_cmp %>%
  filter(!is.na(treatment_cmp), treatment_cmp %in% common_treatments) %>%
  mutate(treatment_cmp = factor(treatment_cmp, levels = common_treatments)) %>%
  group_by(treatment_cmp, condition) %>%
  summarise(mean_fc = mean(frac_correct, na.rm = TRUE), .groups = "drop")

p_bar4_treat <- ggplot(bar4_treat,
                       aes(x = fc_mid, y = frac, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", width = 0.08, alpha = 0.85) +
  geom_vline(data = means4_treat,
             aes(xintercept = mean_fc, colour = condition),
             linewidth = 0.9, linetype = "dashed") +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  scale_colour_manual(values = cmp_colours, labels = cmp_labels,
                      guide = "none") +
  scale_x_continuous(breaks = seq(0, 1, 0.25)) +
  facet_wrap(~ treatment_cmp, nrow = 1) +
  labs(
    x = "Fraction correct (max_correct / n_guesses)",
    y = "Fraction of subjects",
    fill = NULL,
    title = "Distribution of accuracy — by treatment"
  ) +
  theme_bw() +
  theme(legend.position = "bottom",
        axis.text.x = element_text(angle = 45, hjust = 1, size = 7))
print(p_bar4_treat)

# SECTION 5: BAR PLOTS — trueextracted--------
cat("\n===== SECTION 5: BAR PLOTS — TRUEEXTRACTED =====\n")

te_cmp <- ss_cmp %>%
  distinct(subject_id, source, treatment, treatment_cmp, condition,
           trueextracted)

# Aggregate
bar5_agg <- te_cmp %>%
  group_by(condition, trueextracted) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(condition) %>%
  mutate(frac = n / sum(n)) %>%
  ungroup()

means5_agg <- te_cmp %>%
  group_by(condition) %>%
  summarise(mean_te = mean(trueextracted, na.rm = TRUE), .groups = "drop")

p_bar5_agg <- ggplot(bar5_agg,
                     aes(x = trueextracted, y = frac, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", width = 0.5, alpha = 0.85) +
  geom_vline(data = means5_agg,
             aes(xintercept = mean_te, colour = condition),
             linewidth = 0.9, linetype = "dashed") +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  scale_colour_manual(values = cmp_colours, labels = cmp_labels,
                      guide = "none") +
  scale_x_continuous(breaks = 0:1,
                     labels = c("Not extracted", "Extracted")) +
  labs(
    x = "True rule extracted",
    y = "Fraction of subjects",
    fill = NULL,
    title = "True rule extraction — aggregate"
  ) +
  theme_bw() +
  theme(legend.position = "bottom")
print(p_bar5_agg)

# By treatment
bar5_treat <- te_cmp %>%
  filter(!is.na(treatment_cmp), treatment_cmp %in% common_treatments) %>%
  mutate(treatment_cmp = factor(treatment_cmp, levels = common_treatments)) %>%
  group_by(treatment_cmp, condition, trueextracted) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(treatment_cmp, condition) %>%
  mutate(frac = n / sum(n)) %>%
  ungroup()

means5_treat <- te_cmp %>%
  filter(!is.na(treatment_cmp), treatment_cmp %in% common_treatments) %>%
  mutate(treatment_cmp = factor(treatment_cmp, levels = common_treatments)) %>%
  group_by(treatment_cmp, condition) %>%
  summarise(mean_te = mean(trueextracted, na.rm = TRUE), .groups = "drop")

p_bar5_treat <- ggplot(bar5_treat,
                       aes(x = trueextracted, y = frac, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", width = 0.5, alpha = 0.85) +
  geom_vline(data = means5_treat,
             aes(xintercept = mean_te, colour = condition),
             linewidth = 0.9, linetype = "dashed") +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  scale_colour_manual(values = cmp_colours, labels = cmp_labels,
                      guide = "none") +
  scale_x_continuous(breaks = 0:1,
                     labels = c("Not extracted", "Extracted")) +
  facet_wrap(~ treatment_cmp, nrow = 1) +
  labs(
    x = "True rule extracted",
    y = "Fraction of subjects",
    fill = NULL,
    title = "True rule extraction — by treatment"
  ) +
  theme_bw() +
  theme(legend.position = "bottom")
print(p_bar5_treat)

# SECTION 6: SUBOPTIMAL EXTRACTION --------
cat("\n===== SECTION 6: SUBOPTIMAL EXTRACTION =====\n")

# Identify, per treatment, the deterministic rule(s) whose prediction equals
# Machine_CorrectP on every light configuration observed for that treatment.
#
# Arguments:
#   data        guess-level data with treatment, Light_Config, Machine_CorrectP
#   rule_table  long rule table with rule_id, Light_Config, rule_pred
# Returns a tibble with columns treatment and true_rule_id.
find_true_rule_ids <- function(data, rule_table) {
  data %>%
    distinct(treatment, Light_Config, Machine_CorrectP) %>%
    left_join(rule_table, by = "Light_Config",
              relationship = "many-to-many") %>%
    group_by(treatment, rule_id) %>%
    summarise(is_true = all(Machine_CorrectP == rule_pred),
              .groups = "drop") %>%
    filter(is_true) %>%
    select(treatment, true_rule_id = rule_id)
}

# Flag, per subject × treatment × condition, whether the guesses agree with
# at least one non-true rule on more than half of the trials AND a two-sided
# binomial test of the agreement count against 0.5 has p <= alpha.
#
# Arguments:
#   data        guess-level data with subject_id, treatment, condition,
#               Light_Config, Guess
#   rule_table  long rule table with rule_id, Light_Config, rule_pred
#   true_ids    output of find_true_rule_ids() for the same data
#   n_trials    number of guesses per subject × treatment (12 for 2L, 16 for 3L)
#   alpha       significance threshold for the binomial test
# Returns a tibble with subject_id, treatment, condition and the 0/1 column
# suboptimalextracted.
flag_suboptimal_extraction <- function(data, rule_table, true_ids, n_trials,
                                       alpha = 0.01) {
  # The p-value depends only on the agreement count, so compute it once for
  # each possible count (0..n_trials) instead of once per subject × rule row.
  p_by_count <- vapply(
    0:n_trials,
    function(k) stats::binom.test(k, n_trials, 0.5)$p.value,
    numeric(1)
  )

  match_counts <- data %>%
    select(subject_id, treatment, condition, Light_Config, Guess) %>%
    left_join(rule_table, by = "Light_Config",
              relationship = "many-to-many") %>%
    left_join(true_ids, by = "treatment") %>%
    filter(rule_id != true_rule_id) %>%
    group_by(subject_id, treatment, condition, rule_id) %>%
    summarise(n_match = sum(Guess == rule_pred), .groups = "drop")

  # binom.test() rejects missing counts and counts above the number of
  # trials; keep failing loudly rather than producing silent NAs.
  stopifnot(
    !anyNA(match_counts$n_match),
    all(match_counts$n_match <= n_trials)
  )

  match_counts %>%
    mutate(
      passes = p_by_count[n_match + 1L] <= alpha & n_match > n_trials / 2
    ) %>%
    group_by(subject_id, treatment, condition) %>%
    summarise(suboptimalextracted = as.integer(any(passes)),
              .groups = "drop")
}

# ---- 2L: 16 rules, n = 12, threshold > 6
rule_table_2L <- expand.grid(p00 = 0:1, p01 = 0:1, p10 = 0:1, p11 = 0:1) %>%
  as_tibble() %>%
  mutate(rule_id = row_number()) %>%
  pivot_longer(-rule_id, names_to = "config_key", values_to = "rule_pred") %>%
  mutate(Light_Config = case_when(
    config_key == "p00" ~ "(0,0)",
    config_key == "p01" ~ "(0,1)",
    config_key == "p10" ~ "(1,0)",
    config_key == "p11" ~ "(1,1)"
  )) %>%
  select(-config_key)

true_ids_2L <- find_true_rule_ids(Data_2L_cmp, rule_table_2L)

subopt_2L <- flag_suboptimal_extraction(
  Data_2L_cmp, rule_table_2L, true_ids_2L,
  n_trials = 12
)

# ---- 3L: 256 rules, n = 16, threshold > 8
rule_table_3L <- expand.grid(
  p000 = 0:1, p001 = 0:1, p010 = 0:1, p011 = 0:1,
  p100 = 0:1, p101 = 0:1, p110 = 0:1, p111 = 0:1
) %>%
  as_tibble() %>%
  mutate(rule_id = row_number()) %>%
  pivot_longer(-rule_id, names_to = "config_key", values_to = "rule_pred") %>%
  mutate(Light_Config = case_when(
    config_key == "p000" ~ "(0,0,0)",
    config_key == "p001" ~ "(0,0,1)",
    config_key == "p010" ~ "(0,1,0)",
    config_key == "p011" ~ "(0,1,1)",
    config_key == "p100" ~ "(1,0,0)",
    config_key == "p101" ~ "(1,0,1)",
    config_key == "p110" ~ "(1,1,0)",
    config_key == "p111" ~ "(1,1,1)"
  )) %>%
  select(-config_key)

true_ids_3L <- find_true_rule_ids(Data_3L_cmp, rule_table_3L)

subopt_3L <- flag_suboptimal_extraction(
  Data_3L_cmp, rule_table_3L, true_ids_3L,
  n_trials = 16
)

# ---- Combine extraction stats
# randomizing = neither the true rule nor any suboptimal rule was extracted.
extract_2L <- te_cmp %>%
  filter(source == "2L") %>%
  left_join(subopt_2L, by = c("subject_id", "treatment", "condition")) %>%
  mutate(
    randomizing = as.integer(trueextracted == 0 & suboptimalextracted == 0)
  )

extract_3L <- te_cmp %>%
  filter(source == "3L") %>%
  left_join(subopt_3L, by = c("subject_id", "treatment", "condition")) %>%
  mutate(
    randomizing = as.integer(trueextracted == 0 & suboptimalextracted == 0)
  )

extraction_cmp <- bind_rows(extract_2L, extract_3L) %>%
  mutate(condition = factor(condition, levels = cmp_levels))

# ---- Summary tables
tbl6_agg <- extraction_cmp %>%
  group_by(condition) %>%
  summarise(
    N = n(),
    true_ext_pct = round(mean(trueextracted, na.rm = TRUE) * 100, 1),
    subopt_pct = round(mean(suboptimalextracted, na.rm = TRUE) * 100, 1),
    rand_pct = round(mean(randomizing, na.rm = TRUE) * 100, 1),
    .groups = "drop"
  )
cat("\n--- Section 6: Extraction stats by condition ---\n")
print(tbl6_agg)

tbl6_treat <- extraction_cmp %>%
  filter(!is.na(treatment_cmp), treatment_cmp %in% common_treatments) %>%
  group_by(treatment_cmp, condition) %>%
  summarise(
    N = n(),
    true_ext_pct = round(mean(trueextracted, na.rm = TRUE) * 100, 1),
    subopt_pct = round(mean(suboptimalextracted, na.rm = TRUE) * 100, 1),
    rand_pct = round(mean(randomizing, na.rm = TRUE) * 100, 1),
    .groups = "drop"
  )
cat("\n--- Section 6: Extraction stats by treatment × condition ---\n")
print(tbl6_treat)

# ---- Bar chart: extraction type by condition (aggregate)
extr6_long_agg <- tbl6_agg %>%
  pivot_longer(c(true_ext_pct, subopt_pct, rand_pct),
               names_to = "type", values_to = "pct") %>%
  mutate(type = factor(
    type,
    levels = c("true_ext_pct", "subopt_pct", "rand_pct"),
    labels = c("True Extracted", "Suboptimal", "Randomizing")
  ))

p_extr6_agg <- ggplot(extr6_long_agg,
                      aes(x = type, y = pct, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 0.85) +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  labs(
    x = NULL,
    y = "% of subjects",
    fill = NULL,
    title = "Extraction type by condition — aggregate"
  ) +
  theme_bw() +
  theme(legend.position = "bottom")
print(p_extr6_agg)

# ---- Bar chart: extraction type by treatment × condition
extr6_long_treat <- tbl6_treat %>%
  pivot_longer(c(true_ext_pct, subopt_pct, rand_pct),
               names_to = "type", values_to = "pct") %>%
  mutate(
    type = factor(
      type,
      levels = c("true_ext_pct", "subopt_pct", "rand_pct"),
      labels = c("True Extracted", "Suboptimal", "Randomizing")
    ),
    treatment_cmp = factor(treatment_cmp, levels = common_treatments)
  )

p_extr6_treat <- ggplot(extr6_long_treat,
                        aes(x = type, y = pct, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 0.85) +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  facet_wrap(~ treatment_cmp, nrow = 1) +
  labs(
    x = NULL,
    y = "% of subjects",
    fill = NULL,
    title = "Extraction type by condition — by treatment"
  ) +
  theme_bw() +
  theme(
    legend.position = "bottom",
    axis.text.x = element_text(angle = 45, hjust = 1, size = 8)
  )
print(p_extr6_treat)

# SECTION 7: Rule_used — posterior bar chart------
cat("\n===== SECTION 7: RULE_USED =====\n")

# Minimum posterior for a row to be kept; rows with a missing posterior are
# always kept.
RULE_POST_THRESHOLD <- 0.0

rule_obs_cmp <- ss_cmp %>%
  filter(is.na(posterior) | posterior >= RULE_POST_THRESHOLD) %>%
  mutate(
    # Collapse s_* (3L-specific composite strategies) to "Other"
    Rule_used_std = if_else(str_starts(Rule_used_std, "s_"),
                            "Other", Rule_used_std)
  )

# Proportion table by treatment × condition × rule (common treatments only)
tbl7 <- rule_obs_cmp %>%
  filter(!is.na(treatment_cmp), treatment_cmp %in% common_treatments) %>%
  mutate(treatment_cmp = factor(treatment_cmp, levels = common_treatments)) %>%
  group_by(treatment_cmp, condition, Rule_used_std) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(treatment_cmp, condition) %>%
  mutate(pct = round(n / sum(n) * 100, 1)) %>%
  ungroup() %>%
  arrange(treatment_cmp, condition, desc(pct))
print(tbl7)

p_rule7 <- tbl7 %>%
  ggplot(aes(x = reorder(Rule_used_std, -pct), y = pct, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 0.85) +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  facet_wrap(~ treatment_cmp, nrow = 1, scales = "free_x") +
  labs(
    x = "Rule used",
    y = "% of subjects",
    fill = NULL,
    title = "Rule used — by treatment"
  ) +
  theme_bw() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, size = 7),
    legend.position = "bottom"
  )
print(p_rule7)

# SECTION 7b: Rule_used — 3-category chart (by treatment)------
cat("\n===== SECTION 7b: RULE CATEGORIES — BY TREATMENT =====\n")

# True rule names after "_rb" stripping — same strings as treatment name
true_rule_map_cmp <- c(
  AND = "AND",
  OR = "OR",
  INHIBIT = "INHIBIT",
  EITHER = "EITHER",
  JOINT = "JOINT"
)

rule_cat_cmp <- rule_obs_cmp %>%
  filter(!is.na(treatment_cmp), treatment_cmp %in% common_treatments) %>%
  mutate(
    true_rule = true_rule_map_cmp[as.character(treatment_cmp)],
    # "Random" takes precedence over "True Extracted"; anything else,
    # including a missing rule, falls through to "Other Rule".
    category = case_when(
      Rule_used_std %in% c("random", "always_p") ~ "Random",
      Rule_used_std == true_rule ~ "True Extracted",
      TRUE ~ "Other Rule"
    ),
    category = factor(category,
                      levels = c("True Extracted", "Other Rule", "Random")),
    treatment_cmp = factor(treatment_cmp, levels = common_treatments)
  )

rule_cat_plot <- rule_cat_cmp %>%
  group_by(treatment_cmp, condition, category) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(treatment_cmp, condition) %>%
  mutate(pct = round(n / sum(n) * 100, 1)) %>%
  ungroup()

p_rule_cat7b <- ggplot(rule_cat_plot,
                       aes(x = category, y = pct, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 0.85) +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  facet_wrap(~ treatment_cmp, nrow = 1) +
  labs(
    x = "Category",
    y = "% of subjects",
    fill = NULL,
    title = "Rule category — by treatment"
  ) +
  theme_bw() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, size = 8),
    legend.position = "bottom"
  )
print(p_rule_cat7b)

# SECTION 7c: Rule_used — 3-category chart (aggregate / pooled)-----
cat("\n===== SECTION 7c: RULE CATEGORIES — POOLED =====\n")

# Pooled across all treatments in ss_cmp (including ALONE for 2L)
rule_cat_all_cmp <- rule_obs_cmp %>%
  mutate(
    true_rule = case_when(
      treatment_cmp %in% names(true_rule_map_cmp) ~
        true_rule_map_cmp[as.character(treatment_cmp)],
      treatment == "ALONE_easy" ~ "red",
      treatment == "ALONE_difficult" ~ "blue",
      TRUE ~ NA_character_
    ),
    category = case_when(
      Rule_used_std %in% c("random", "always_p") ~ "Random",
      !is.na(true_rule) & Rule_used_std == true_rule ~ "True Extracted",
      TRUE ~ "Other Rule"
    ),
    category = factor(category,
                      levels = c("True Extracted", "Other Rule", "Random"))
  )

rule_cat_pooled <- rule_cat_all_cmp %>%
  group_by(condition, category) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(condition) %>%
  mutate(pct = round(n / sum(n) * 100, 1)) %>%
  ungroup()

p_rule_cat7c <- ggplot(rule_cat_pooled,
                       aes(x = category, y = pct, fill = condition)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 0.85) +
  scale_fill_manual(values = cmp_fills, labels = cmp_labels) +
  labs(
    x = "Category",
    y = "Share of subjects",
    fill = NULL,
    title = ""
  ) +
  theme_bw() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, size = 8),
    legend.position = "bottom"
  )
print(p_rule_cat7c)

# SECTION 8: RULES WITH STRICTLY POSITIVE POSTERIOR --------
cat("\n===== SECTION 8: RULES WITH STRICTLY POSITIVE POSTERIOR =====\n")

# For each condition, count distinct rules assigned posterior > 0.
# Each row in ss_cmp is one subject × treatment with its best-fit rule and
# posterior; a rule "counts" if it appears with posterior > 0 for at least
# one subject in that condition.
# Per condition: number of distinct subjects and of distinct (standardised)
# rules among rows whose posterior is present and strictly positive.
tbl8_agg <- rule_obs_cmp %>%
  filter(!is.na(posterior) & posterior > 0) %>%
  group_by(condition) %>%
  summarise(
    N_subjects  = n_distinct(subject_id),
    N_rules_pos = n_distinct(Rule_used_std),
    .groups = "drop"
  ) %>%
  mutate(condition = factor(condition, levels = cmp_levels)) %>%
  arrange(condition)

cat("\n--- Section 8: Rules with strictly positive posterior by condition ---\n")
print(tbl8_agg)

# Same counts, split by treatment and restricted to the treatments common to
# both datasets.
tbl8_treat <- rule_obs_cmp %>%
  filter(!is.na(posterior) & posterior > 0) %>%
  filter(!is.na(treatment_cmp) & treatment_cmp %in% common_treatments) %>%
  mutate(
    treatment_cmp = factor(treatment_cmp, levels = common_treatments),
    condition = factor(condition, levels = cmp_levels)
  ) %>%
  group_by(treatment_cmp, condition) %>%
  summarise(
    N_subjects  = n_distinct(subject_id),
    N_rules_pos = n_distinct(Rule_used_std),
    .groups = "drop"
  ) %>%
  arrange(treatment_cmp, condition)

cat("\n--- Section 8: Rules with strictly positive posterior by treatment × condition ---\n")
print(tbl8_treat)

# Per condition: mean across treatments of the per-treatment rule count.
tbl8_avg <- tbl8_treat %>%
  group_by(condition) %>%
  summarise(avg_rules_pos = round(mean(N_rules_pos), 2), .groups = "drop") %>%
  mutate(condition = factor(condition, levels = cmp_levels)) %>%
  arrange(condition)

cat("\n--- Section 8: Average rules with strictly positive posterior by condition ---\n")
print(tbl8_avg)