% habit formation study

% Start from a clean console/workspace and fix the random seed
% (Mersenne Twister, seed 0).
clc, clear
rng(0,'twister');

% Every CSV in the current folder is treated as one subject's data file.
D = dir('*.csv');
if isempty(D)
    % Fail here with a clear message: without any input the subject arrays
    % stay empty and the group statistics fail later on an indexing error.
    error('habit:noData', 'No *.csv files found in %s.', pwd);
end

% initialize subject-level arrays (one row per subject, NaN until filled)
subjProbe = nan(numel(D), 4);        % [train_soc, train_ind, test_soc, test_ind]
subjLearnSocial = nan(numel(D), 3);  % [early, mid, late]
subjLearnIndiv = nan(numel(D), 3);   % [early, mid, late]

% Means of the first, second and remaining part of a column vector x, where
% the first two bins are w elements long and the last bin takes the rest.
% An empty bin yields NaN (mean of an empty slice).
thirdMeans = @(x, w) [mean(x(1:w)), mean(x(w+1:2*w)), mean(x(2*w+1:end))];

for s = 1:numel(D)
    %% Data organizing
    % Read the text-valued columns as char so the strcmp calls below
    % operate on cell arrays of character vectors.
    opts = detectImportOptions(D(s).name);
    opts = setvartype(opts, {'phase', 'block_type', ...
        'left_shape_reward_level', 'right_shape_reward_level'}, 'char');
    dat = readtable(D(s).name, opts);

    % main variables (keep only trials with a non-NaN choice)
    validIdx = ~isnan(dat.choice);
    choice = dat.choice(validIdx);
    trialType = dat.trial_type(validIdx);  % 1=normal, 0=probe
    phase = dat.phase(validIdx);
    blockType = dat.block_type(validIdx);
    leftShapeRew = dat.left_shape_reward_level(validIdx);
    rightShapeRew = dat.right_shape_reward_level(validIdx);
    rewLeft = dat.rewardLeft(validIdx);
    rewRight = dat.rewardRight(validIdx);

    % choice code 37 is scored against the left option, 39 against the right
    choseLeft = (choice == 37);
    choseRight = (choice == 39);

    %% probe choices: P(choose originally-high shape)
    % indices: phase(1=train, 2=test) x blockType(1=social, 2=indiv)
    % Any phase other than 'test' is counted as train, and any block type
    % other than 'individual' is counted as social.
    isProbe = (trialType == 0);
    isTest = strcmp(phase, 'test');
    isIndiv = strcmp(blockType, 'individual');
    choseHigh = (choseLeft & strcmp(leftShapeRew, 'high')) | ...
                (choseRight & strcmp(rightShapeRew, 'high'));

    phaseMask = [~isTest, isTest];    % column 1 = train, column 2 = test
    blockMask = [~isIndiv, isIndiv];  % column 1 = social, column 2 = indiv
    probeHigh = zeros(2,2);
    probeTotal = zeros(2,2);
    for ph = 1:2
        for bl = 1:2
            inCell = isProbe & phaseMask(:,ph) & blockMask(:,bl);
            probeTotal(ph, bl) = sum(inCell);
            probeHigh(ph, bl) = sum(inCell & choseHigh);
        end
    end
    if any(probeTotal(:) == 0)
        % 0/0 below gives NaN; flag it so the missing cell is not silent
        warning('habit:noProbeTrials', ...
            '%s: at least one phase x block cell has no probe trials (proportion is NaN).', ...
            D(s).name);
    end
    probeProp = probeHigh ./ probeTotal;  % [train;test] x [social,indiv]

    %% learning curves: P(choose currently-optimal) in test phase
    % A choice is optimal when the chosen side's reward is >= the other
    % side's reward on that trial (ties count as optimal for either side).
    choseOpt = (choseLeft & (rewLeft >= rewRight)) | ...
               (choseRight & (rewRight >= rewLeft));

    % test-phase normal trials, split by block type (trial order preserved)
    isTestNormal = (trialType == 1) & isTest;
    chooseOptSocial = double(choseOpt(isTestNormal & strcmp(blockType, 'social')));
    chooseOptIndiv = double(choseOpt(isTestNormal & isIndiv));

    % bin into thirds (early, mid, late)
    learnSocial = thirdMeans(chooseOptSocial, floor(numel(chooseOptSocial)/3));
    learnIndiv = thirdMeans(chooseOptIndiv, floor(numel(chooseOptIndiv)/3));

    %% store subject data
    % probe: [train_soc, train_ind, test_soc, test_ind]
    subjProbe(s,:) = [probeProp(1,1), probeProp(1,2), probeProp(2,1), probeProp(2,2)];

    % learning curves: [early, mid, late]
    subjLearnSocial(s,:) = learnSocial;
    subjLearnIndiv(s,:) = learnIndiv;
end

% Write the three subject-level result matrices to habitSpecs.mat.
save('habitSpecs', 'subjProbe', 'subjLearnSocial', 'subjLearnIndiv');

%% group stats
nSubj = size(subjProbe, 1);
fprintf('\n============================================================\n');
fprintf('Group Results\n');
fprintf('============================================================\n');

%% ------------------------------------------------------------------------
%  Probe Choices
%% ------------------------------------------------------------------------
fprintf('\n=== Probe Choices (Probability of choosing originally-high shape) ===\n');

% NaN-aware summaries. A subject with no probe trials in a cell has NaN
% there; ttest drops NaNs, so the descriptives and effect sizes must drop
% them too or they would print NaN next to a valid t statistic.
avgOf = @(x) mean(x, 'omitnan');
sdOf = @(x) std(x, 'omitnan');

% habit strength = test probe - 0.5
habitSocial = subjProbe(:,3) - 0.5;
habitIndiv = subjProbe(:,4) - 0.5;

% learning = training probe - 0.5
learningSocial = subjProbe(:,1) - 0.5;
learningIndiv = subjProbe(:,2) - 0.5;

% subjProbe columns: [train_soc, train_ind, test_soc, test_ind]
fprintf('Training:\n');
fprintf('  Social: M=%.3f, SD=%.3f\n', avgOf(subjProbe(:,1)), sdOf(subjProbe(:,1)));
fprintf('  Individual: M=%.3f, SD=%.3f\n', avgOf(subjProbe(:,2)), sdOf(subjProbe(:,2)));
fprintf('Test:\n');
fprintf('  Social: M=%.3f, SD=%.3f\n', avgOf(subjProbe(:,3)), sdOf(subjProbe(:,3)));
fprintf('  Individual: M=%.3f, SD=%.3f\n', avgOf(subjProbe(:,4)), sdOf(subjProbe(:,4)));

% One-sample t-tests against chance (0.5) for each phase and condition.
% d is the mean deviation from 0.5 divided by the SD of that deviation.
chanceTests = { ...
    '\n--- Learning (training probe vs chance) ---\n',    [1 2]; ...
    '\n--- Habit Strength (test probe vs chance) ---\n',  [3 4]};
condNames = {'Social', 'Individual'};
for k = 1:size(chanceTests, 1)
    fprintf(chanceTests{k,1});
    for c = 1:2
        x = subjProbe(:, chanceTests{k,2}(c));
        [~,p,ci,stats] = ttest(x, 0.5);
        d = avgOf(x - 0.5)/sdOf(x - 0.5);
        fprintf([condNames{c} ': t(%d)=%.2f, p=%.3f, d=%.2f, 95%% CI [%.3f, %.3f]\n'], ...
            stats.df, stats.tstat, p, d, ci(1), ci(2));
    end
end

fprintf('\n--- Social vs Individual Habit ---\n');
[~,p,ci,stats] = ttest(subjProbe(:,3), subjProbe(:,4));
% the difference is NaN wherever either condition is NaN, so omitting NaNs
% keeps exactly the complete pairs
habitDiff = habitSocial - habitIndiv;
d = avgOf(habitDiff)/sdOf(habitDiff);
fprintf('Paired t-test: t(%d)=%.2f, p=%.3f, d=%.2f, 95%% CI [%.3f, %.3f]\n', ...
    stats.df, stats.tstat, p, d, ci(1), ci(2));

fprintf('\nHabit Strength:\n');
fprintf('  Social: %.3f (vs chance)\n', avgOf(habitSocial));
fprintf('  Individual: %.3f (vs chance)\n', avgOf(habitIndiv));
fprintf('  Difference: %.3f\n', avgOf(habitSocial) - avgOf(habitIndiv));

%% ------------------------------------------------------------------------
%  Learning Curves - Adaptation to Reversal
%% ------------------------------------------------------------------------
fprintf('\n=== Learning Curves (Probability of choosing currently-optimal in test) ===\n');

% Per-bin descriptives for each condition, ignoring NaN entries.
% Columns of each matrix are [early, mid, late].
curveSets = {subjLearnSocial, subjLearnIndiv};
curveNames = {'Social', 'Individual'};
binNames = {'Early', 'Mid', 'Late'};
for c = 1:numel(curveSets)
    fprintf('%s:\n', curveNames{c});
    for b = 1:numel(binNames)
        binVals = curveSets{c}(:, b);
        fprintf('  %s: M=%.3f, SD=%.3f\n', binNames{b}, ...
            mean(binVals, 'omitnan'), std(binVals, 'omitnan'));
    end
end

fprintf('\n--- Adaptation (Late - Early) ---\n');
adaptSocial = subjLearnSocial(:,3) - subjLearnSocial(:,1);
adaptIndiv = subjLearnIndiv(:,3) - subjLearnIndiv(:,1);

fprintf('Social adaptation: M=%.3f, SD=%.3f\n', ...
    mean(adaptSocial, 'omitnan'), std(adaptSocial, 'omitnan'));
fprintf('Individual adaptation: M=%.3f, SD=%.3f\n', ...
    mean(adaptIndiv, 'omitnan'), std(adaptIndiv, 'omitnan'));

% is adaptation different from 0? (d = mean / SD of the change scores)
[~,p,~,stats] = ttest(adaptSocial);
d = mean(adaptSocial, 'omitnan')/std(adaptSocial, 'omitnan');
fprintf('\nSocial adaptation vs 0: t(%d)=%.2f, p=%.3f, d=%.2f\n', stats.df, stats.tstat, p, d);

[~,p,~,stats] = ttest(adaptIndiv);
d = mean(adaptIndiv, 'omitnan')/std(adaptIndiv, 'omitnan');
fprintf('Individual adaptation vs 0: t(%d)=%.2f, p=%.3f, d=%.2f\n', stats.df, stats.tstat, p, d);

% social vs individual adaptation (paired; d from the paired differences)
[~,p,ci,stats] = ttest(adaptSocial, adaptIndiv);
adaptDiff = adaptSocial - adaptIndiv;
d = mean(adaptDiff, 'omitnan')/std(adaptDiff, 'omitnan');
fprintf('\nSocial vs Individual adaptation: t(%d)=%.2f, p=%.3f, d=%.2f, 95%% CI [%.3f, %.3f]\n', ...
    stats.df, stats.tstat, p, d, ci(1), ci(2));

%% ------------------------------------------------------------------------
%  Summary
%% ------------------------------------------------------------------------
% Compare condition means descriptively (no significance test here). All
% means skip NaN entries: a comparison against a NaN mean is always false
% and would silently print 'Does not support hypothesis'.
meanHabitSocial = mean(habitSocial, 'omitnan');
meanHabitIndiv = mean(habitIndiv, 'omitnan');
meanAdaptSocial = mean(adaptSocial, 'omitnan');
meanAdaptIndiv = mean(adaptIndiv, 'omitnan');

fprintf('\n=== Summary ===\n');
fprintf('\nHypothesis: Social blocks create stronger habits (less adaptation)\n');
fprintf('\nResults:\n');
fprintf('1. Habit strength (test probe - 0.5):\n');
fprintf('   Social: %.3f, Individual: %.3f\n', meanHabitSocial, meanHabitIndiv);
fprintf('   Prediction: Social > Individual → %s\n', ...
    ternary(meanHabitSocial > meanHabitIndiv, 'Supports hypothesis', 'Does not support hypothesis'));

fprintf('\n2. Adaptation rate (late - early in test):\n');
fprintf('   Social: %.3f, Individual: %.3f\n', meanAdaptSocial, meanAdaptIndiv);
fprintf('   Prediction: Social < Individual → %s\n', ...
    ternary(meanAdaptSocial < meanAdaptIndiv, 'Supports hypothesis', 'Does not support hypothesis'));

fprintf('\n============================================================\n');


%% helper function
function out = ternary(cond, a, b)
%TERNARY Return a when cond is true, otherwise b.
%   out = TERNARY(cond, a, b) follows the truth rules of an IF statement
%   on cond. Both a and b are evaluated by the caller before this function
%   runs, so this does not short-circuit.
    out = b;
    if cond
        out = a;
    end
end