% habit figures: probe-choice plots by phase/block type and by block


clc
clear

% load precomputed per-subject data (expected to provide subjProbe)
load('habitSpecs.mat')
% one row of subjProbe per subject
[nSubj, ~] = size(subjProbe);

%% figure 1: probe choices by phase and block type
% Bar plot of the across-subject mean P(choose originally-high shape) for
% each phase x block-type condition, with SEM error bars and jittered
% per-subject points.

figure('Position', [100 100 500 400])

% get means, SDs and SEMs (NaN entries are ignored)
% subjProbe columns: [train_soc, train_ind, test_soc, test_ind]
probeData = subjProbe(:, 1:4);
means = mean(probeData, 1, 'omitnan');
sds = std(probeData, 0, 1, 'omitnan');
% divide by the number of subjects with data in each condition, not nSubj,
% so conditions with missing values don't get artificially small SEMs
sems = sds ./ sqrt(sum(~isnan(probeData), 1));

% bar positions: training pair at 1-2, test pair at 4-5 (gap between phases)
x = [1, 2, 4, 5];

% dark grey = social, light grey = individual (same order as the columns)
colors = [0.3, 0.3, 0.3;
          0.7, 0.7, 0.7;
          0.3, 0.3, 0.3;
          0.7, 0.7, 0.7];

hold on

% plot bars
for i = 1:4
    bar(x(i), means(i), 0.8, 'FaceColor', colors(i,:), 'EdgeColor', 'k', 'LineWidth', 1.5);
end

% error bars (SEM)
errorbar(x, means, sems, 'k', 'LineStyle', 'none', 'LineWidth', 1.5, 'CapSize', 8);

% individual data points; one seeded jitter value per subject, so each
% subject sits at the same horizontal offset in every condition
jitterAmount = 0.15;
rng(42);
jitterVals = (rand(nSubj, 1) - 0.5) * jitterAmount;

for i = 1:4
    scatter(x(i) + jitterVals, probeData(:, i), 30, 'k', 'filled', 'MarkerFaceAlpha', 0.6);
end

% chance line
yline(0.5, 'r--', 'LineWidth', 1.5);

% labels
set(gca, 'XTick', [1.5, 4.5], 'XTickLabel', {'Training', 'Test'}, 'FontSize', 12);
ylabel('P(Choose Originally-High Shape)', 'FontSize', 12);
ylim([0 1.1]);
title('Probe Choices: Habit Formation', 'FontSize', 14);

% legend - NaN (invisible) bars act as proxies for the two block-type colours
h1 = bar(nan, nan, 'FaceColor', [0.3 0.3 0.3], 'EdgeColor', 'k');
h2 = bar(nan, nan, 'FaceColor', [0.7 0.7 0.7], 'EdgeColor', 'k');
legend([h1, h2], {'Social', 'Individual'}, 'Location', 'northeast', 'FontSize', 11);

hold off
set(gcf, 'Color', 'w')


%% figure 2: block-by-block Line Plot
% Re-reads each subject's raw CSV, computes P(choose originally-high shape)
% on probe trials for every social and individual block, and plots the
% across-subject mean +/- SEM per block.

D = dir('*.csv');
% NOTE(review): assumes dir() (alphabetical) order matches the row order of
% subjProbe and that this folder holds only subject CSVs -- confirm.
if numel(D) < nSubj
    error('habitFigures:csvCount', ...
        'Found %d CSV files but subjProbe has %d subjects.', numel(D), nSubj);
end

nBlocks = 12;          % blocks per session, social and individual combined
nBlocksPerType = 6;    % blocks kept per block type

% 6 social blocks, 6 individual blocks per subject (NaN = no probe data)
socialByBlock = nan(nSubj, nBlocksPerType);
indivByBlock = nan(nSubj, nBlocksPerType);

for s = 1:nSubj
    fname = fullfile(D(s).folder, D(s).name);

    % force the text columns to char so strcmp / {1} indexing below work
    opts = detectImportOptions(fname);
    opts = setvartype(opts, 'phase', 'char');
    opts = setvartype(opts, 'block_type', 'char');
    opts = setvartype(opts, 'left_shape_reward_level', 'char');
    opts = setvartype(opts, 'right_shape_reward_level', 'char');
    dat = readtable(fname, opts);

    % keep only rows with a recorded choice
    dat = dat(~isnan(dat.choice), :);

    % detect trials per block based on total trials
    % old version: 144 trials = 12 per block
    % new version: 192 trials = 16 per block
    % NOTE(review): blocks are numbered AFTER NaN-choice rows are dropped; if
    % those rows can be missed task trials (not only non-task rows), later
    % trials shift into the wrong block -- confirm.
    nTrials = height(dat);
    switch nTrials
        case 144
            trialsPerBlock = 12;
        case 192
            trialsPerBlock = 16;
        otherwise
            trialsPerBlock = 12;  % fallback
            warning('habitFigures:unexpectedTrialCount', ...
                '%s: %d valid trials (expected 144 or 192); assuming %d trials per block.', ...
                D(s).name, nTrials, trialsPerBlock);
    end

    % add block number (1-12), assuming rows are in presentation order
    dat.block_num = floor((0:nTrials-1)' / trialsPerBlock) + 1;

    % get probe trials
    probes = dat(dat.trial_type == 0, :);

    % track social and individual block indices separately
    socialIdx = 0;
    indivIdx = 0;

    for block = 1:nBlocks
        blockProbes = probes(probes.block_num == block, :);
        if height(blockProbes) == 0
            continue
        end
        blockType = blockProbes.block_type{1};

        % P(choose high): choice 37 selects the left shape, 39 the right
        % (presumably arrow-key codes -- confirm)
        choseHigh = (blockProbes.choice == 37 & strcmp(blockProbes.left_shape_reward_level, 'high')) | ...
                    (blockProbes.choice == 39 & strcmp(blockProbes.right_shape_reward_level, 'high'));
        pHigh = mean(choseHigh);

        if strcmp(blockType, 'social')
            socialIdx = socialIdx + 1;
            if socialIdx <= nBlocksPerType
                socialByBlock(s, socialIdx) = pHigh;
            end
        else
            indivIdx = indivIdx + 1;
            if indivIdx <= nBlocksPerType
                indivByBlock(s, indivIdx) = pHigh;
            end
        end
    end
end

% get means and SEMs; SEM divides by the number of subjects with data in
% each block (not nSubj) so missing blocks don't shrink the error bars
socialMeans = mean(socialByBlock, 1, 'omitnan');
socialSEMs = std(socialByBlock, 0, 1, 'omitnan') ./ sqrt(sum(~isnan(socialByBlock), 1));
indivMeans = mean(indivByBlock, 1, 'omitnan');
indivSEMs = std(indivByBlock, 0, 1, 'omitnan') ./ sqrt(sum(~isnan(indivByBlock), 1));

% plot
figure('Position', [100 100 600 450])

x = 1:nBlocksPerType;

hold on

% plot lines with error bars (handles kept for an explicit legend)
hSocial = errorbar(x, socialMeans, socialSEMs, 'o-', 'Color', 'k', ...
    'LineWidth', 2, 'MarkerSize', 8, 'MarkerFaceColor', 'k', 'CapSize', 5);
hIndiv = errorbar(x, indivMeans, indivSEMs, 's--', 'Color', [0.5 0.5 0.5], ...
    'LineWidth', 2, 'MarkerSize', 8, 'MarkerFaceColor', [0.5 0.5 0.5], 'CapSize', 5);

% chance line, plus divider between blocks 1-3 (Training) and 4-6 (Test)
yline(0.5, 'r--', 'LineWidth', 1);
xline(3.5, 'k--', 'LineWidth', 1.5);

% labels
set(gca, 'XTick', 1:6, 'XTickLabel', {'1', '2', '3', '4', '5', '6'}, 'FontSize', 12);
xlabel('Block', 'FontSize', 12);
ylabel('P(Choose Originally-High Shape)', 'FontSize', 12);
title('Probe Choices by Block', 'FontSize', 14);
ylim([0 1]);
xlim([0.5 6.5]);

% floating text for Train and Test
text(2, 0.05, 'Training', 'HorizontalAlignment', 'center', 'FontSize', 12, 'FontWeight', 'bold');
text(5, 0.05, 'Test', 'HorizontalAlignment', 'center', 'FontSize', 12, 'FontWeight', 'bold');

% legend
legend([hSocial, hIndiv], {'Social', 'Individual'}, 'Location', 'northeast', 'FontSize', 11);

hold off
set(gcf, 'Color', 'w')

%% print summary stats
% Echo the figure-1 condition means/SDs and the figure-2 per-block means.

fprintf('\n=== SUMMARY STATISTICS ===\n');
fprintf('\nProbe Choices:\n');

% label prefixes (including their padding) in subjProbe column order
probeLabels = {'Training Social:    ', ...
               'Training Individual: ', ...
               'Test Social:        ', ...
               'Test Individual:    '};
for k = 1:numel(probeLabels)
    fprintf('%sM = %.3f, SD = %.3f\n', probeLabels{k}, means(k), sds(k));
end

fprintf('\nBlock-by-block means:\n');
fprintf('Social:     %s\n', sprintf('%.2f  ', socialMeans));
fprintf('Individual: %s\n', sprintf('%.2f  ', indivMeans));