from otree.api import *
import random
import math

# Short description of this app (judges, subgroups divisible by 4, per-round
# reshuffling); presumably surfaced by oTree in its admin UI — confirm.
doc = """
PGG with judges, subgroups divisible by 4, reshuffle each round.
"""


# =====================================================
# KONŠTANTY
# =====================================================
class C(BaseConstants):
    """Constants shared by every page and model of this app."""

    # --- framework settings ---
    NAME_IN_URL = 'base'
    # A single oTree group holds everyone; the actual PGG groups are tracked
    # manually through Player.pgg_group.
    PLAYERS_PER_GROUP = None
    NUM_ROUNDS = 3

    # --- payoff parameters ---
    ENDOWMENT = 20  # points each contributor receives per round
    MPCR = 0.4  # marginal per-capita return applied to the group total


# =====================================================
# MODELY
# =====================================================
class Subsession(BaseSubsession):
    """Subsession model; this app adds no subsession-level fields."""


class Group(BaseGroup):
    """Group model; PGG grouping is stored on Player.pgg_group instead."""


class Player(BasePlayer):
    """Per-round player record: role, grouping and PGG decision."""

    # Role and subgroup are fixed in round 1 (stored in participant.vars) and
    # copied onto these fields every round by SetupWaitPage.
    is_judge = models.BooleanField(initial=False)
    subgroup = models.IntegerField(initial=0)  # 0 = not assigned
    # PGG group id for the current round; reshuffled every round.
    # Judges are never assigned one, so they stay at 0.
    pgg_group = models.IntegerField(initial=0)

    # PGG fields
    # Upper bound tied to the endowment so the two can never drift apart.
    contribution = models.IntegerField(min=0, max=C.ENDOWMENT, initial=0)
    payoff_pre = models.FloatField(initial=0)  # unrounded payoff, for debugging / display


# =====================================================
# STRÁNKY
# =====================================================
class Intro(Page):
    """Introductory page, shown only in the first round."""

    @staticmethod
    def is_displayed(player: Player):
        first_round = player.round_number == 1
        return first_round

    @staticmethod
    def vars_for_template(player: Player):
        treatment_type = player.session.config.get('treatment_type')
        return {'treatment': treatment_type}


class SetupWaitPage(WaitPage):
    """Fix roles/subgroups in round 1 and reshuffle PGG groups every round.

    Round 1, judge treatments: the players with id_in_subsession 1 and 2
    become judges of subgroups 1 and 2. The remaining players are shuffled
    and split into two subgroups whose sizes are multiples of 4 and as equal
    as possible. A ValueError is raised when no such split exists.

    Round 1, other treatments: everyone is put into subgroup 1.

    Every round: the fixed values are copied from participant.vars onto the
    Player fields. The non-judges of each subgroup are then shuffled and cut
    into consecutive PGG groups of 4, numbered globally from 1.
    """

    @staticmethod
    def after_all_players_arrive(group: Group):
        # A regular WaitPage passes the group here. With PLAYERS_PER_GROUP =
        # None that single group spans the subsession, so we work on the
        # whole subsession.
        subsession = group.subsession
        everyone = subsession.get_players()

        # ---- Round 1 only: decide judges and subgroups (stored for good) ----
        if subsession.round_number == 1:
            treatment = subsession.session.config.get('treatment_type')

            for pl in everyone:
                pl.participant.vars['is_judge'] = False
                pl.participant.vars['subgroup'] = 0

            if treatment not in ['human_judge', 'humanAI_judge']:
                for pl in everyone:
                    pl.participant.vars['subgroup'] = 1
            else:
                # id_in_subsession -> subgroup judged by that player
                judge_of = {1: 1, 2: 2}
                for pl in everyone:
                    judged = judge_of.get(pl.id_in_subsession)
                    if judged is not None:
                        pl.participant.vars['is_judge'] = True
                        pl.participant.vars['subgroup'] = judged

                pool = [pl for pl in everyone if not pl.participant.vars['is_judge']]
                total = len(pool)

                # Both subgroups must be non-empty multiples of 4.
                candidates = [
                    (first, total - first)
                    for first in range(4, total, 4)
                    if total - first >= 4 and (total - first) % 4 == 0
                ]
                if not candidates:
                    raise ValueError("Normálnych hráčov nemožno rozdeliť na subgroup deliteľné 4.")

                # Most balanced split (first one wins on ties).
                size_1, size_2 = min(candidates, key=lambda split: abs(split[0] - split[1]))

                random.shuffle(pool)
                for idx, pl in enumerate(pool[:size_1 + size_2]):
                    pl.participant.vars['subgroup'] = 1 if idx < size_1 else 2

        # ---- Every round: copy the fixed values onto this round's Player ----
        for pl in everyone:
            pl.is_judge = pl.participant.vars.get('is_judge', False)
            pl.subgroup = pl.participant.vars.get('subgroup', 0)
            pl.pgg_group = 0  # cleared every round; judges keep 0

        # ---- Every round: fresh PGG groups of 4 inside each subgroup ----
        contributors = [pl for pl in everyone if not pl.is_judge]
        next_id = 1
        for sg in sorted({pl.subgroup for pl in contributors}):
            members = [pl for pl in contributors if pl.subgroup == sg]
            random.shuffle(members)
            for pos, pl in enumerate(members):
                pl.pgg_group = next_id + pos // 4
            # advance by the number of chunks used (ceil(len / 4))
            next_id += (len(members) + 3) // 4


class Cooperation(Page):
    """Contribution decision page; only non-judges see it."""

    form_model = 'player'
    form_fields = ['contribution']

    @staticmethod
    def is_displayed(player: Player):
        # Judges make no contribution, so they skip this page.
        show = not player.is_judge
        return show

    @staticmethod
    def vars_for_template(player: Player):
        return {'endowment': C.ENDOWMENT, 'mpcr': C.MPCR}


class ResultsWaitPage(WaitPage):
    """Compute round payoffs once every player has arrived.

    Non-judges are paid by the linear public-goods formula within their
    pgg_group, rounded up to an integer. Each judge receives the rounded-up
    average payoff of the non-judges in the judge's subgroup, or 0 if that
    subgroup has no non-judges.
    """

    @staticmethod
    def after_all_players_arrive(group: Group):
        # With PLAYERS_PER_GROUP = None the single oTree group passed here
        # spans the whole subsession.
        players = group.subsession.get_players()
        normal_players = [p for p in players if not p.is_judge]

        # 1) PGG payoffs, computed separately for each pgg_group.
        members_by_pgg = {}
        for p in normal_players:
            members_by_pgg.setdefault(p.pgg_group, []).append(p)

        for members in members_by_pgg.values():
            # Normally 4 members. The last group of a subgroup can be smaller
            # when the subgroup size is not a multiple of 4.
            total_contribution = sum(m.contribution for m in members)
            for m in members:
                payoff_pre = C.ENDOWMENT - m.contribution + C.MPCR * total_contribution
                m.payoff_pre = payoff_pre
                # Round away binary float noise before ceil. Otherwise, e.g.,
                # 0.07 * 100 == 7.000000000000001 would be paid as 8.
                m.payoff = math.ceil(round(payoff_pre, 9))

        # 2) Average non-judge payoff per subgroup, rounded up.
        payoffs_by_subgroup = {}
        for p in normal_players:
            payoffs_by_subgroup.setdefault(p.subgroup, []).append(p.payoff)
        subgroup_avg_payoffs = {
            sg: math.ceil(sum(payoffs) / len(payoffs))
            for sg, payoffs in payoffs_by_subgroup.items()
        }

        # 3) Judges get their subgroup's average. An empty subgroup now yields
        #    0 instead of raising KeyError (the old fallback was unreachable).
        for p in players:
            if p.is_judge:
                judge_payoff = subgroup_avg_payoffs.get(p.subgroup, 0)
                p.payoff = judge_payoff
                p.payoff_pre = judge_payoff  # kept equal for display consistency

class Results(Page):
    """Per-round results page; judges and contributors get different data."""

    @staticmethod
    def vars_for_template(player: Player):
        # Fetched once and shared by both branches.
        players = player.subsession.get_players()

        if player.is_judge:
            # A judge's payoff is the average payoff of the subgroup it judges
            # (set in ResultsWaitPage); also report that subgroup's size.
            num_normal_in_subgroup = sum(
                1 for pl in players
                if not pl.is_judge and pl.subgroup == player.subgroup
            )
            return dict(
                is_judge=True,
                round_number=player.round_number,
                subgroup=player.subgroup,
                round_payoff=player.payoff,
                payoff_pre=player.payoff_pre,
                num_normal_in_subgroup=num_normal_in_subgroup,
            )

        # Non-judge: show the contributions of everyone in the same pgg_group.
        group_members = [
            pl for pl in players
            if not pl.is_judge and pl.pgg_group == player.pgg_group
        ]
        total_contribution = sum(pl.contribution for pl in group_members)

        # Anonymous labels A, B, C, ... assigned in get_players() order.
        anonymous_members = [
            {'label': chr(ord('A') + i), 'contribution': pl.contribution}
            for i, pl in enumerate(group_members)
        ]

        return dict(
            is_judge=False,
            round_number=player.round_number,
            my_contribution=player.contribution,
            total_contribution=total_contribution,
            payoff_pre=player.payoff_pre,
            round_payoff=player.payoff,
            anonymous_members=anonymous_members,
        )


class SetupInfo(Page):
    """Debug page listing the round, the treatment and all players."""

    @staticmethod
    def vars_for_template(player: Player):
        config = player.session.config
        everyone = player.subsession.get_players()
        return {
            'round': player.round_number,
            'treatment': config.get('treatment_type'),
            'players': everyone,
        }


# =====================================================
# SEKVENCIA
# =====================================================
# Pages shown every round, in this order. SetupInfo comes last as an
# optional per-round debug view.
page_sequence = [
    Intro, SetupWaitPage, Cooperation,
    ResultsWaitPage, Results, SetupInfo,
]
