from otree.api import *
import random
import math
from id_manager import get_profile

# Module-level description string for this oTree app.
doc = """
PGG with judges + Norms & Empirical Expectations (Pre/Post).
"""

class C(BaseConstants):
    """Experiment-wide constants."""

    NAME_IN_URL = 'base'
    # None -> one oTree group holds every player; the PGG groups of (up to)
    # four are formed by hand each round in SetupWaitPage.
    PLAYERS_PER_GROUP = None
    NUM_ROUNDS = 3  # Keep 3 for testing, change to 15 later

    # Public goods game parameters
    ENDOWMENT = 20  # per-round endowment of each contributing player
    MPCR = 0.4  # multiplier applied to the PGG group's total contribution

class Subsession(BaseSubsession):
    """Per-round subsession; this app adds no subsession-level fields."""

class Group(BaseGroup):
    """Single oTree group (PLAYERS_PER_GROUP is None); no extra fields."""

class Player(BasePlayer):
    """Per-round player record.

    Holds the role/grouping assigned in SetupWaitPage, the PGG decision and
    outcome, a judge's punishment inputs, the norm-elicitation answers (pre in
    round 1, post in the last round), Part 1 data imported on IDInput, and the
    payment bookkeeping written in the last round.
    """

    # --- Role and grouping (refreshed from participant.vars every round) ---
    is_judge = models.BooleanField(initial=False)
    subgroup = models.IntegerField(initial=0)  # 0 = not assigned yet
    pgg_group = models.IntegerField(initial=0)  # PGG group id, reshuffled every round
    pid_input = models.StringField(initial='')  # Part 1 ID typed on IDInput

    # PGG fields
    contribution = models.IntegerField(min=0, max=20, initial=0)  # max mirrors C.ENDOWMENT (20)
    payoff_pre = models.FloatField(initial=0.0)  # unrounded payoff before punishment
    punishment_received = models.IntegerField(initial=0)  # punishment points deducted this round
    group_punishment_cost = models.FloatField(initial=0.0)  # this member's share of the PGG group's total punishment

    # Judge Fields
    # pK_punish = punishment (0-5) a judge assigns to the K-th non-judge player
    # of their subgroup, in subsession.get_players() order. Only 12 slots
    # exist, so a judged subgroup must not contain more than 12 players.
    p1_punish = models.IntegerField(min=0, max=5, initial=0)
    p2_punish = models.IntegerField(min=0, max=5, initial=0)
    p3_punish = models.IntegerField(min=0, max=5, initial=0)
    p4_punish = models.IntegerField(min=0, max=5, initial=0)
    p5_punish = models.IntegerField(min=0, max=5, initial=0)
    p6_punish = models.IntegerField(min=0, max=5, initial=0)
    p7_punish = models.IntegerField(min=0, max=5, initial=0)
    p8_punish = models.IntegerField(min=0, max=5, initial=0)
    p9_punish = models.IntegerField(min=0, max=5, initial=0)
    p10_punish = models.IntegerField(min=0, max=5, initial=0)
    p11_punish = models.IntegerField(min=0, max=5, initial=0)
    p12_punish = models.IntegerField(min=0, max=5, initial=0)

    # --- Norms Pre Fields (asked in round 1) ---
    # np_L: personal-norm answer for reference level L (NormsPersonal page).
    # NOTE(review): np_* and nn_* share identical labels; presumably the
    # templates supply the actual question wording — confirm.
    np_0 = models.IntegerField(label="Avg contrib 0:", min=0, max=20)
    np_5 = models.IntegerField(label="Avg contrib 5:", min=0, max=20)
    np_10 = models.IntegerField(label="Avg contrib 10:", min=0, max=20)
    np_15 = models.IntegerField(label="Avg contrib 15:", min=0, max=20)
    np_20 = models.IntegerField(label="Avg contrib 20:", min=0, max=20)

    # nn_L: normative answer for level L (NormsNormative page); rewarded when
    # close to the session mean of np_L.
    nn_0 = models.IntegerField(label="Avg contrib 0:", min=0, max=20)
    nn_5 = models.IntegerField(label="Avg contrib 5:", min=0, max=20)
    nn_10 = models.IntegerField(label="Avg contrib 10:", min=0, max=20)
    nn_15 = models.IntegerField(label="Avg contrib 15:", min=0, max=20)
    nn_20 = models.IntegerField(label="Avg contrib 20:", min=0, max=20)

    # Empirical expectation: guessed average contribution (NormsEmpirical page).
    expected_contribution = models.IntegerField(label="Avg contribution session:", min=0, max=20)

    # --- Norms Post Fields (same questions, asked in the last round) ---
    np_0_post = models.IntegerField(label="Avg contrib 0:", min=0, max=20)
    np_5_post = models.IntegerField(label="Avg contrib 5:", min=0, max=20)
    np_10_post = models.IntegerField(label="Avg contrib 10:", min=0, max=20)
    np_15_post = models.IntegerField(label="Avg contrib 15:", min=0, max=20)
    np_20_post = models.IntegerField(label="Avg contrib 20:", min=0, max=20)

    nn_0_post = models.IntegerField(label="Avg contrib 0:", min=0, max=20)
    nn_5_post = models.IntegerField(label="Avg contrib 5:", min=0, max=20)
    nn_10_post = models.IntegerField(label="Avg contrib 10:", min=0, max=20)
    nn_15_post = models.IntegerField(label="Avg contrib 15:", min=0, max=20)
    nn_20_post = models.IntegerField(label="Avg contrib 20:", min=0, max=20)

    expected_contribution_post = models.IntegerField(label="Avg contribution session:", min=0, max=20)

    # --- NEW FIELDS FOR DATA EXPORT & PAYMENT (filled on the last-round Player) ---
    chosen_level_pre = models.IntegerField(initial=0)  # reference level drawn for pre-norm scoring
    chosen_level_post = models.IntegerField(initial=0)  # reference level drawn for post-norm scoring
    
    avg_personal_norm_pre = models.FloatField(initial=0.0)  # session mean of np_<chosen_level_pre>
    avg_personal_norm_post = models.FloatField(initial=0.0)  # session mean of np_<chosen_level_post>_post
    avg_contribution_session = models.FloatField(initial=0.0)  # mean non-judge contribution over all rounds
    
    bonus_norm_pre = models.CurrencyField(initial=0)
    bonus_norm_post = models.CurrencyField(initial=0)
    bonus_emp_pre = models.CurrencyField(initial=0)
    bonus_emp_post = models.CurrencyField(initial=0)

    # --- NEW FIELDS FOR PART 1 DATA (copied from the Part 1 profile on IDInput) ---
    p1_cooperate_PD = models.BooleanField()
    p1_cooperate_Cond_PD_C = models.BooleanField()
    p1_cooperate_Cond_PD_D = models.BooleanField()
    
    p1_boxes_collected = models.IntegerField()
    p1_bomb = models.IntegerField()
    
    p1_epper_scenario = models.StringField()
    p1_epper_choice = models.IntegerField()
    
    epper_bonus = models.CurrencyField(initial=0)  # magnitude of the profile's own Epper payoff
    epper_bonus_other = models.CurrencyField(initial=0)  # profile's Epper payoff for the other person
    
    # --- OUTPUT FIELDS FOR FINAL PAYMENT ---
    pd_payoff_final = models.CurrencyField(initial=0)
    bret_payoff_final = models.CurrencyField(initial=0)
    epper_payoff_final = models.CurrencyField(initial=0)

    # --- Payment Fields ---
    paid_rounds_str = models.StringField(initial="")  # e.g. "[1, 2, 3]": PGG rounds that are paid
    pgg_payoff_sum = models.CurrencyField(initial=0)
    pgg_czk_total = models.CurrencyField(initial=0)
    norms_total_bonus = models.CurrencyField(initial=0)
    final_payment_czk = models.CurrencyField(initial=0)


### TESTING ONLY: creating_session below pre-fills random norm answers; comment it out for real sessions.

def creating_session(subsession: Subsession):
    """Pre-fill the norm questions with random answers (testing aid).

    Pre-game answers are generated in round 1 and post-game answers in the
    last round (both when NUM_ROUNDS == 1). Disable for real sessions.
    """
    reference_levels = [0, 5, 10, 15, 20]
    round_no = subsession.round_number

    def _randomise(player, suffix, expectation_field):
        # For each level draw the personal-norm answer, then the normative one,
        # and finally the empirical expectation.
        for level in reference_levels:
            setattr(player, f'np_{level}{suffix}', random.randint(0, 20))
            setattr(player, f'nn_{level}{suffix}', random.randint(0, 20))
        setattr(player, expectation_field, random.randint(0, 20))

    if round_no == 1:
        for player in subsession.get_players():
            _randomise(player, '', 'expected_contribution')

    if round_no == C.NUM_ROUNDS:
        for player in subsession.get_players():
            _randomise(player, '_post', 'expected_contribution_post')


# =====================================================
# PAGES
# =====================================================

class IDInput(Page):
    """Round-1 page where the participant enters their Part 1 ID.

    The ID is validated against the Part 1 profile store; on success the
    profile is kept in participant.vars and its answers are copied into
    Player fields for export.
    """

    form_model = 'player'
    form_fields = ['pid_input']

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    @staticmethod
    def error_message(player, values):
        # Reject IDs with no profile, and profiles that did not finish Part 1.
        typed_id = values['pid_input'].strip()
        found = get_profile(typed_id)
        if not found:
            return f"ID '{typed_id}' not found. Please check your ID from Part 1."
        if not found.get('finished_part_1'):
            return f"ID '{typed_id}' exists but did not finish Part 1."
        return None

    @staticmethod
    def before_next_page(player, timeout_happened):
        typed_id = player.pid_input.strip()
        found = get_profile(typed_id)
        pvars = player.participant.vars

        # Raw profile for later pages (Part1Summary reads 'p1_data').
        pvars['p1_id'] = typed_id
        pvars['p1_data'] = found

        # Exported copies of the Part 1 answers; .get() tolerates missing keys.
        player.p1_cooperate_PD = found.get('cooperate_PD')
        player.p1_cooperate_Cond_PD_C = found.get('cooperate_Cond_PD_C')
        player.p1_cooperate_Cond_PD_D = found.get('cooperate_Cond_PD_D')
        player.p1_boxes_collected = found.get('boxes_collected', 0)
        player.p1_bomb = found.get('bomb', 0)
        player.p1_epper_scenario = found.get('epper_scenario_selected')
        player.p1_epper_choice = found.get('epper_choice_raw')

        # The stored own-payoff can be negative (e.g. -750); its magnitude is paid.
        player.epper_bonus = abs(found.get('epper_payoff_me', 0))
        player.epper_bonus_other = found.get('epper_payoff_other', 0)
        pvars['epper_bonus'] = player.epper_bonus

        print(f"SUCCESS: Saved Part 1 data for {typed_id} to Player DB.")


class Part1CalculationWaitPage(WaitPage):
    """Round-1 barrier that settles the Part 1 payoffs for the whole session.

    Writes 'pd_payoff'/'pd_role', 'epper_final'/'epper_role' and
    'bret_payoff' into every participant's vars.
    """

    wait_for_all_groups = True  # matching spans the whole session

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    @staticmethod
    def after_all_players_arrive(subsession: Subsession):
        everyone = subsession.get_players()

        # ---------- 1. Prisoner's dilemma: random pairs and random roles ----------
        payoff_table = {
            (True, True): 750, (True, False): 500,
            (False, True): 1000, (False, False): 600,
        }
        shuffled = list(everyone)
        random.shuffle(shuffled)

        if len(shuffled) % 2 == 1:
            # With an odd count, the last player after shuffling sits out.
            odd_one = shuffled.pop()
            odd_one.participant.vars['pd_payoff'] = 0
            odd_one.participant.vars['pd_role'] = 'Leftover (Odd number)'

        # Pairs are taken from the back: (last, second-to-last), (third-to-last, ...), ...
        from_back = shuffled[::-1]
        for first, second in zip(from_back[0::2], from_back[1::2]):
            mover, responder = (first, second) if random.choice([True, False]) else (second, first)

            mover_coop = mover.p1_cooperate_PD  # unconditional choice
            responder_coop = (
                responder.p1_cooperate_Cond_PD_C if mover_coop
                else responder.p1_cooperate_Cond_PD_D
            )

            mover.participant.vars['pd_payoff'] = payoff_table.get((mover_coop, responder_coop), 0)
            mover.participant.vars['pd_role'] = 'Mover (Unconditional)'
            responder.participant.vars['pd_payoff'] = payoff_table.get((responder_coop, mover_coop), 0)
            responder.participant.vars['pd_role'] = 'Responder (Conditional)'

        # ---------- 2. Epper: half paid their own choice, half paid someone else's ----------
        epper_order = list(everyone)
        random.shuffle(epper_order)
        half = len(epper_order) // 2
        self_paid, other_paid = epper_order[:half], epper_order[half:]

        handed_over = []
        for p in self_paid:
            p.participant.vars['epper_final'] = p.epper_bonus
            p.participant.vars['epper_role'] = 'Paid Self Choice'
            handed_over.append(p.epper_bonus_other)

        random.shuffle(handed_over)
        for idx, p in enumerate(other_paid):
            # other_paid may be one longer than self_paid; that extra player gets 0.
            p.participant.vars['epper_final'] = handed_over[idx] if idx < len(handed_over) else 0
            p.participant.vars['epper_role'] = 'Paid Other Choice'

        # ---------- 3. BRET: 30 per collected box, zero whenever p1_bomb is truthy ----------
        for p in everyone:
            bret = 0 if p.p1_bomb else (p.p1_boxes_collected or 0) * 30
            p.participant.vars['bret_payoff'] = bret


class Part1Summary(Page):
    """Round-1 page listing the imported Part 1 answers and the payoffs
    settled in Part1CalculationWaitPage."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    @staticmethod
    def vars_for_template(player: Player):
        pvars = player.participant.vars
        raw = pvars.get('p1_data', {})

        # Template variable name -> key in the raw Part 1 profile.
        profile_keys = {
            'cooperate_PD': 'cooperate_PD',
            'cooperate_Cond_PD_C': 'cooperate_Cond_PD_C',
            'cooperate_Cond_PD_D': 'cooperate_Cond_PD_D',
            'boxes_collected': 'boxes_collected',
            'bomb': 'bomb',
            'epper_scenario': 'epper_scenario_selected',
            'epper_choice': 'epper_choice_raw',
            'epper_payoff_me': 'epper_payoff_me',
            'epper_payoff_other': 'epper_payoff_other',
            'finished': 'finished_part_1',
        }

        context = {'p1_id': pvars.get('p1_id', 'Unknown')}
        context.update({name: raw.get(key) for name, key in profile_keys.items()})
        # Results computed on the wait page (defaults if it has not run).
        context.update(
            pd_role=pvars.get('pd_role', 'Not calculated'),
            pd_payoff=pvars.get('pd_payoff', 0),
            epper_role=pvars.get('epper_role', 'Not calculated'),
            epper_final=pvars.get('epper_final', 0),
            bret_payoff=pvars.get('bret_payoff', 0),
        )
        return context



class Intro(Page):
    """Round-1 introduction; the template receives the configured treatment."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    @staticmethod
    def vars_for_template(player: Player):
        treatment_name = player.session.config.get('treatment_type')
        return {'treatment': treatment_name}


class SetupWaitPage(WaitPage):
    """Per-round setup of roles, per-round fields and PGG groups.

    Round 1 fixes roles in participant.vars:
    - In 'human_judge'/'humanAI_judge', players 1 and 2 judge subgroups 1 and 2.
      The other players are split into two subgroups whose sizes are
      multiples of 4 (a ValueError is raised if that is impossible).
    - Every other treatment puts everyone in subgroup 1 with no judges.

    Every round copies the roles onto the Player, resets the punishment
    fields, and reshuffles players into PGG groups of (up to) four within
    each subgroup.
    """

    @staticmethod
    def after_all_players_arrive(group: Group):
        subsession = group.subsession
        everyone = subsession.get_players()
        treatment = subsession.session.config.get('treatment_type')

        def _pick_split(n):
            """Return (size_1, size_2), both multiples of 4, as balanced as possible."""
            candidates = [
                (size_1, n - size_1)
                for size_1 in range(4, n, 4)
                if n - size_1 >= 4 and (n - size_1) % 4 == 0
            ]
            if candidates:
                return min(candidates, key=lambda pair: abs(pair[0] - pair[1]))
            if n % 8 == 0:
                return n // 2, n // 2
            raise ValueError(f"Cannot split {n} players into subgroups divisible by 4.")

        # ---- Round 1: fixed roles ----
        if subsession.round_number == 1:
            for p in everyone:
                p.participant.vars['is_judge'] = False
                p.participant.vars['subgroup'] = 0

            if treatment in ['human_judge', 'humanAI_judge']:
                judged_subgroup = {1: 1, 2: 2}  # id_in_subsession -> subgroup judged
                for p in everyone:
                    sg = judged_subgroup.get(p.id_in_subsession)
                    if sg is not None:
                        p.participant.vars['is_judge'] = True
                        p.participant.vars['subgroup'] = sg

                members = [p for p in everyone if not p.participant.vars['is_judge']]
                size_1, size_2 = _pick_split(len(members))
                random.shuffle(members)
                for p in members[:size_1]:
                    p.participant.vars['subgroup'] = 1
                for p in members[size_1:size_1 + size_2]:
                    p.participant.vars['subgroup'] = 2
            else:
                # no_judge / AI_judge: nobody judges, everyone in subgroup 1.
                for p in everyone:
                    p.participant.vars['subgroup'] = 1

        # ---- Every round: copy roles and clear per-round values ----
        for p in everyone:
            p.is_judge = p.participant.vars.get('is_judge', False)
            p.subgroup = p.participant.vars.get('subgroup', 0)
            p.pgg_group = 0
            p.punishment_received = 0
            p.group_punishment_cost = 0.0
            if p.is_judge:
                for k in range(1, 13):
                    setattr(p, f'p{k}_punish', 0)

        # ---- Every round: new PGG groups, numbered session-wide in subgroup order ----
        regular = [p for p in everyone if not p.is_judge]
        next_id = 1
        for sg in sorted({p.subgroup for p in regular}):
            pool = [p for p in regular if p.subgroup == sg]
            random.shuffle(pool)
            for start in range(0, len(pool), 4):
                for p in pool[start:start + 4]:
                    p.pgg_group = next_id
                next_id += 1

class Instructions(Page):
    """Round-1 instructions page.

    The template receives the treatment name, whether the player is a judge,
    and the PGG parameters (endowment and MPCR).
    """

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    @staticmethod
    def vars_for_template(player: Player):
        return dict(
            treatment=player.session.config.get('treatment_type'),
            is_judge=player.is_judge,
            endowment=C.ENDOWMENT,
            mpcr=C.MPCR,
        )

class NormsPersonal(Page):
    """Pre-game personal norms: one answer per reference level (round 1 only)."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    form_model = 'player'
    form_fields = ['np_0', 'np_5', 'np_10', 'np_15', 'np_20']

class NormsNormative(Page):
    """Pre-game normative answers, one per reference level (round 1 only)."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    form_model = 'player'
    form_fields = ['nn_0', 'nn_5', 'nn_10', 'nn_15', 'nn_20']

class NormsEmpirical(Page):
    """Pre-game empirical expectation of the average contribution (round 1 only)."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    form_model = 'player'
    form_fields = ['expected_contribution']


class Cooperation(Page):
    """PGG contribution decision; judges skip it."""

    form_model = 'player'
    form_fields = ['contribution']

    @staticmethod
    def is_displayed(player: Player):
        return not player.is_judge

    @staticmethod
    def vars_for_template(player: Player):
        return {'endowment': C.ENDOWMENT, 'mpcr': C.MPCR}

class ResultsWaitPage(WaitPage):
    """Compute pre-punishment PGG payoffs and the judges' payoffs."""

    @staticmethod
    def after_all_players_arrive(group: Group):
        everyone = group.subsession.get_players()
        contributors = [p for p in everyone if not p.is_judge]

        # PGG: uncontributed endowment + MPCR * group total; payoff is rounded up.
        for gid in {p.pgg_group for p in contributors}:
            members = [p for p in contributors if p.pgg_group == gid]
            pot = sum(p.contribution for p in members)
            for p in members:
                p.payoff_pre = float(C.ENDOWMENT - p.contribution + (C.MPCR * pot))
                p.payoff = math.ceil(p.payoff_pre)

        # Judges: rounded-up mean of their subgroup's (already rounded) payoffs.
        for judge in (p for p in everyone if p.is_judge):
            supervised = [p for p in contributors if p.subgroup == judge.subgroup]
            if not supervised:
                judge.payoff = 0
                continue
            mean_payoff = float(sum(p.payoff for p in supervised) / len(supervised))
            judge.payoff = math.ceil(mean_payoff)

class JudgeWaitPage(WaitPage):
    """Apply the automatic punishment rule where the treatment uses one.

    - 'AI_judge': punishment_received is set directly on every player.
    - 'humanAI_judge': the rule's suggestions pre-fill each human judge's
      pK_punish fields before the Judge page.
    - any other treatment: nothing happens here.
    """

    @staticmethod
    def after_all_players_arrive(group: Group):
        subsession = group.subsession
        treatment = subsession.session.config.get('treatment_type')

        def _suggested_points(members):
            """Rule per member, in order: min(5, max(0, round((avg - contrib) / 2)))."""
            mean = sum(m.contribution for m in members) / len(members)
            return [min(5, max(0, round((mean - m.contribution) / 2))) for m in members]

        if treatment == 'AI_judge':
            # No human judges exist in this treatment, so every player is punishable.
            everyone = subsession.get_players()
            for sg in set(p.subgroup for p in everyone):
                members = [p for p in everyone if p.subgroup == sg]
                if not members:
                    continue
                for p, pts in zip(members, _suggested_points(members)):
                    p.punishment_received = pts

        elif treatment == 'humanAI_judge':
            everyone = subsession.get_players()
            for judge in [p for p in everyone if p.is_judge]:
                members = [p for p in everyone
                           if not p.is_judge and p.subgroup == judge.subgroup]
                if not members:
                    continue
                # NOTE(review): only p1..p12_punish exist; a subgroup with more
                # than 12 players would fail here.
                for slot, pts in enumerate(_suggested_points(members), start=1):
                    setattr(judge, f'p{slot}_punish', pts)

class Judge(Page):
    """Human judges assign 0-5 punishment points to each non-judge player of
    their subgroup; field pK_punish maps to the K-th such player."""

    form_model = 'player'

    @staticmethod
    def is_displayed(player: Player):
        treatment = player.session.config.get('treatment_type')
        return player.is_judge and treatment in ['human_judge', 'humanAI_judge']

    @staticmethod
    def get_form_fields(player: Player):
        count = sum(
            1 for p in player.subsession.get_players()
            if not p.is_judge and p.subgroup == player.subgroup
        )
        return [f'p{k}_punish' for k in range(1, count + 1)]

    @staticmethod
    def vars_for_template(player: Player):
        supervised = [p for p in player.subsession.get_players()
                      if not p.is_judge and p.subgroup == player.subgroup]

        # PGG group id -> list of {'player', 'field'} items, in first-seen order.
        by_pgg_group = {}
        for k, p in enumerate(supervised, start=1):
            by_pgg_group.setdefault(p.pgg_group, []).append(
                {'player': p, 'field': f'p{k}_punish'}
            )
        return {'pgg_groups': by_pgg_group}

    @staticmethod
    def before_next_page(player: Player, timeout_happened):
        supervised = [p for p in player.subsession.get_players()
                      if not p.is_judge and p.subgroup == player.subgroup]
        # Copy the judge's decisions onto the punished players.
        for k, p in enumerate(supervised, start=1):
            p.punishment_received = getattr(player, f'p{k}_punish')

class FinalWaitPage(WaitPage):
    """Apply punishments to the non-judge players' PGG payoffs.

    A punished player loses their own punishment points. In addition, every
    member of the PGG group pays an equal share of the group's total
    punishment. The final payoff is rounded up.
    """

    @staticmethod
    def after_all_players_arrive(group: Group):
        normal_players = [p for p in group.subsession.get_players() if not p.is_judge]

        for g_id in {p.pgg_group for p in normal_players}:
            members = [p for p in normal_players if p.pgg_group == g_id]
            total_punish = sum(m.punishment_received for m in members)
            # Divide by the actual group size: SetupWaitPage forms groups of
            # *up to* four, so a fixed divisor of 4 under-charged smaller groups.
            shared_cost = total_punish / len(members)

            for m in members:
                m.group_punishment_cost = shared_cost
                final_val = float(m.payoff_pre - m.punishment_received - shared_cost)
                m.payoff = math.ceil(final_val)

class Results(Page):
    """Round results: anonymised contributions and punishments in the
    player's PGG group (judges get an empty context)."""

    @staticmethod
    def vars_for_template(player: Player):
        if player.is_judge:
            return {}

        peers = [p for p in player.subsession.get_players()
                 if not p.is_judge and p.pgg_group == player.pgg_group]

        # Members are labelled A, B, C, ... in get_players() order.
        anon_members = [
            {
                'label': chr(ord('A') + idx),
                'contribution': p.contribution,
                'punishment': p.punishment_received,
            }
            for idx, p in enumerate(peers)
        ]
        return {
            'anon_members': anon_members,
            'total_group_punish': sum(p.punishment_received for p in peers),
            'total_contribution': sum(p.contribution for p in peers),
        }

class NormsPersonalPost(Page):
    """Post-game personal norms, one answer per reference level (last round only)."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == C.NUM_ROUNDS

    form_model = 'player'
    form_fields = ['np_0_post', 'np_5_post', 'np_10_post', 'np_15_post', 'np_20_post']

class NormsNormativePost(Page):
    """Post-game normative answers, one per reference level (last round only)."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == C.NUM_ROUNDS

    form_model = 'player'
    form_fields = ['nn_0_post', 'nn_5_post', 'nn_10_post', 'nn_15_post', 'nn_20_post']

class NormsEmpiricalPost(Page):
    """Post-game empirical expectation of the average contribution (last round only)."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == C.NUM_ROUNDS

    form_model = 'player'
    form_fields = ['expected_contribution_post']


class NormsAndPaymentWaitPage(WaitPage):
    """Score the norm questions and compute final payments.

    The page is reached every round but only does work in the last round
    (C.NUM_ROUNDS). All results are written to the last-round Player, where
    ResultsSummary reads them and where they are exported.

    final payment = 3 * (sum of PGG payoffs in up to 3 randomly drawn rounds)
                    + norm bonuses (50 for each accurate guess, up to 4)
                    + the Epper payoff settled in Part1CalculationWaitPage.
    """

    @staticmethod
    def after_all_players_arrive(group: Group):
        subsession = group.subsession
        if subsession.round_number != C.NUM_ROUNDS:
            return

        levels = [0, 5, 10, 15, 20]  # reference levels of the norm questions
        norm_bonus = 50  # paid for each guess within `tolerance` of its target
        tolerance = 3
        rounds_to_pay = 3  # PGG rounds drawn per player (all rounds if fewer exist)
        pgg_multiplier = 3  # conversion of summed PGG payoffs to CZK

        players_round1 = subsession.in_round(1).get_players()
        players_roundN = subsession.in_round(C.NUM_ROUNDS).get_players()

        # 1. One reference level each for the pre and post questions, shared by everyone.
        pre_level = random.choice(levels)
        post_level = random.choice(levels)

        # 2. Targets for the normative guesses: session means of the personal norms.
        pre_vals = [getattr(p, f'np_{pre_level}') for p in players_round1]
        avg_pre_personal = sum(pre_vals) / len(pre_vals) if pre_vals else 0
        post_vals = [getattr(p, f'np_{post_level}_post') for p in players_roundN]
        avg_post_personal = sum(post_vals) / len(post_vals) if post_vals else 0

        # 3. Target for the empirical guesses: mean non-judge contribution over all
        #    rounds of THIS app. in_all_rounds() is app-scoped, unlike
        #    session.get_subsessions(), which would include other apps' players.
        contributions = [
            p.contribution
            for ss in subsession.in_all_rounds()
            for p in ss.get_players()
            if not p.is_judge
        ]
        avg_contrib = sum(contributions) / len(contributions) if contributions else 0

        all_round_nums = list(range(1, C.NUM_ROUNDS + 1))

        for p_last in players_roundN:
            p_first = p_last.in_round(1)  # pre-game answers live on the round-1 Player

            # 4. Norm bonuses. The shared targets are stored on every player for export.
            p_last.chosen_level_pre = pre_level
            p_last.chosen_level_post = post_level
            p_last.avg_personal_norm_pre = avg_pre_personal
            p_last.avg_personal_norm_post = avg_post_personal
            p_last.avg_contribution_session = avg_contrib

            if abs(getattr(p_first, f'nn_{pre_level}') - avg_pre_personal) <= tolerance:
                p_last.bonus_norm_pre = norm_bonus
            if abs(p_first.expected_contribution - avg_contrib) <= tolerance:
                p_last.bonus_emp_pre = norm_bonus
            if abs(getattr(p_last, f'nn_{post_level}_post') - avg_post_personal) <= tolerance:
                p_last.bonus_norm_post = norm_bonus
            if abs(p_last.expected_contribution_post - avg_contrib) <= tolerance:
                p_last.bonus_emp_post = norm_bonus

            # 5. PGG payment: distinct rounds drawn independently for each player.
            if C.NUM_ROUNDS >= rounds_to_pay:
                selected_rounds = sorted(random.sample(all_round_nums, rounds_to_pay))
            else:
                selected_rounds = list(all_round_nums)
            pgg_sum = sum(p_last.in_round(r_num).payoff for r_num in selected_rounds)
            pgg_czk = pgg_sum * pgg_multiplier

            norms_sum = (p_last.bonus_norm_pre + p_last.bonus_emp_pre
                         + p_last.bonus_norm_post + p_last.bonus_emp_post)

            # 6. Part 1 payoffs. Part1CalculationWaitPage stores the settled Epper
            #    payout under 'epper_final' (there is no 'epper_payoff' key).
            pvars = p_last.participant.vars
            epper_payoff = pvars.get('epper_final', 0)
            p_last.epper_payoff_final = epper_payoff
            # NOTE(review): PD and BRET payoffs are exported here but not added to
            # final_total — confirm whether they should be paid.
            p_last.pd_payoff_final = pvars.get('pd_payoff', 0)
            p_last.bret_payoff_final = pvars.get('bret_payoff', 0)

            final_total = pgg_czk + norms_sum + epper_payoff
            print(f"PAYMENT: PGG={pgg_czk}, Norms={norms_sum}, Epper={epper_payoff} -> Total={final_total}")

            # 7. Persist for export and for ResultsSummary.
            p_last.paid_rounds_str = str(selected_rounds)
            p_last.pgg_payoff_sum = pgg_sum
            p_last.pgg_czk_total = pgg_czk
            p_last.norms_total_bonus = norms_sum
            p_last.final_payment_czk = final_total

            # Mirror the total on the participant for oTree's admin payment view.
            p_last.participant.payoff = final_total


class ResultsSummary(Page):
    """Last-round summary of norm bonuses and payment, read from Player fields."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == C.NUM_ROUNDS

    @staticmethod
    def vars_for_template(player: Player):
        first = player.in_round(1)  # pre-game answers are stored on the round-1 Player
        pre_level = player.chosen_level_pre
        post_level = player.chosen_level_post

        pre_part = {
            'norm_pre_level': pre_level,
            'avg_pre_personal': round(player.avg_personal_norm_pre, 2),
            'my_nn_pre': getattr(first, f'nn_{pre_level}'),
            'bonus_norm_pre': player.bonus_norm_pre,
            'my_emp_pre': first.expected_contribution,
            'bonus_emp_pre': player.bonus_emp_pre,
        }
        post_part = {
            'norm_post_level': post_level,
            'avg_post_personal': round(player.avg_personal_norm_post, 2),
            'my_nn_post': getattr(player, f'nn_{post_level}_post'),
            'bonus_norm_post': player.bonus_norm_post,
            'my_emp_post': player.expected_contribution_post,
            'bonus_emp_post': player.bonus_emp_post,
        }
        bonus_total = (player.bonus_norm_pre + player.bonus_emp_pre
                       + player.bonus_norm_post + player.bonus_emp_post)
        payment_part = {
            'avg_contribution_session': round(player.avg_contribution_session, 2),
            'total_bonus': bonus_total,
            'paid_rounds_str': player.paid_rounds_str,
            'pgg_payoff_sum': player.pgg_payoff_sum,
            'pgg_czk_total': player.pgg_czk_total,
            'norms_total_bonus': player.norms_total_bonus,
            'final_payment_czk': player.final_payment_czk,
        }
        return {**pre_part, **post_part, **payment_part}

page_sequence = [
    # Part 1 import and settlement (round 1 only)
    IDInput,
    Part1CalculationWaitPage,
    Part1Summary,
    # Setup: roles and groups every round; Intro/Instructions in round 1 only
    Intro,
    SetupWaitPage,
    Instructions,
    # Pre-game norms (round 1 only)
    NormsPersonal,
    NormsNormative,
    NormsEmpirical,
    # One PGG round with judging and punishment
    Cooperation,
    ResultsWaitPage,
    JudgeWaitPage,
    Judge,
    FinalWaitPage,
    Results,
    # Post-game norms and payment (effective in the last round)
    NormsPersonalPost,
    NormsNormativePost,
    NormsEmpiricalPost,
    NormsAndPaymentWaitPage,
    ResultsSummary,
]