from collections import defaultdict
from decimal import Decimal
from functools import reduce
from html import escape as html_escape
from typing import List, Dict

from otree import settings
from otree.database import db, dbq
from otree import common
from otree.common import (
    get_models_module,
    get_builtin_constant,
    validate_alphanumeric,
    get_bots_module,
    get_constants,
)
from otree.currency import RealWorldCurrency
from otree.models import Participant, Session
from otree.constants import BaseConstants, get_roles, get_role


def gcd(a, b):
    """Return greatest common divisor using Euclid's Algorithm."""
    while b:
        a, b = b, a % b
    return a


def lcm(a, b):
    """Return lowest common multiple."""
    return a * b // gcd(a, b)


def lcmm(*args):
    """Return lcm of args.

    Raises TypeError when called with no arguments, because ``reduce`` is
    given no initial value.
    """
    return reduce(lcm, args)


class SessionConfigError(Exception):
    pass


# Keys describing the session config itself. They are never offered as
# editable form fields, and create_session rejects attempts to override them.
NON_EDITABLE_FIELDS = {
    'name',
    'display_name',
    'app_sequence',
    'num_demo_participants',
    'doc',
}


class SessionConfig(dict):
    """One session config: an entry of settings.SESSION_CONFIGS layered over
    SESSION_CONFIG_DEFAULTS (see get_session_configs_dict), with helpers for
    validation and for rendering its editable fields as HTML inputs."""

    # convenient access
    @property
    def app_sequence(self) -> list:
        return self['app_sequence']

    @property
    def participation_fee(self) -> RealWorldCurrency:
        return self['participation_fee']

    def get_lcm(self):
        """Return the least common multiple of players_per_group over all
        apps in the sequence. create_session requires the number of
        participants to be a multiple of this value."""
        group_sizes = [
            # a players_per_group of None, 0, etc. is treated as 1
            get_builtin_constant(app_name, 'players_per_group') or 1
            for app_name in self['app_sequence']
        ]
        return lcmm(*group_sizes)

    def get_num_bot_cases(self):
        """Return the largest number of PlayerBot cases defined by any app
        in the sequence (at least 1)."""
        num_cases = 1
        for app_name in self['app_sequence']:
            bots_module = get_bots_module(app_name)
            cases = bots_module.PlayerBot.cases
            num_cases = max(num_cases, len(cases))
        return num_cases

    def clean(self):
        """Validate this config in place and normalize some values.

        Checks required keys, the name, and app_sequence, then fills in
        display_name/doc defaults and converts participation_fee to
        RealWorldCurrency.

        Raises:
            SessionConfigError: if a required key is missing, or
                app_sequence is empty or contains duplicates.
        """
        for k in ['participation_fee', 'real_world_currency_per_point']:
            if k not in self:
                msg = f'SESSION_CONFIG_DEFAULTS is missing {k}'
                raise SessionConfigError(msg)

        for k in ['name', 'app_sequence', 'num_demo_participants']:
            if k not in self:
                msg = f'Session config is missing {k}'
                raise SessionConfigError(msg)

        validate_alphanumeric(
            self['name'], identifier_description='settings.SESSION_CONFIGS name'
        )

        app_sequence = self['app_sequence']
        if len(app_sequence) != len(set(app_sequence)):
            msg = (
                'settings.SESSION_CONFIGS: '
                'app_sequence of "{}" '
                'must not contain duplicate elements. '
                'If you want multiple rounds, '
                'you should set num_rounds.'
            )
            raise SessionConfigError(msg.format(self['name']))

        if len(app_sequence) == 0:
            msg = 'settings.SESSION_CONFIGS: app_sequence cannot be empty.'
            raise SessionConfigError(msg)

        self.setdefault('display_name', self['name'])
        self.setdefault('doc', '')

        self['participation_fee'] = RealWorldCurrency(self['participation_fee'])

    def app_sequence_display(self):
        """Return a list of {'doc', 'name'} dicts, one per app.

        'doc' is the models module's ``doc`` attribute ('' if absent).
        'name' gets an '(N rounds)' suffix when the app has more than one
        round.
        """
        app_sequence = []
        for app_name in self['app_sequence']:
            num_rounds = get_builtin_constant(app_name, 'num_rounds')
            models_module = get_models_module(app_name)
            if num_rounds > 1:
                formatted_app_name = '{} ({} rounds)'.format(app_name, num_rounds)
            else:
                formatted_app_name = app_name
            app_sequence.append(
                {
                    'doc': getattr(models_module, 'doc', ''),
                    'name': formatted_app_name,
                }
            )
        return app_sequence

    def builtin_editable_fields(self):
        """Return the built-in editable keys: participation_fee, plus
        real_world_currency_per_point when settings.USE_POINTS is set."""
        fields = ['participation_fee']
        if settings.USE_POINTS:
            fields.append('real_world_currency_per_point')
        return fields

    def custom_editable_fields(self):
        """Return user-defined keys whose values are exactly bool, int,
        float or str, excluding NON_EDITABLE_FIELDS and built-in editable
        fields."""
        # Keys may contain any character. editable_field_html HTML-escapes
        # keys and values before putting them into the form markup, so
        # quotes, '<', '&', non-ASCII etc. cannot break the HTML.
        # NOTE(review): html_field_name joins the config name and the key
        # with a period; confirm that keys containing periods are parsed
        # correctly wherever the submitted form is read.
        builtin_fields = self.builtin_editable_fields()
        return [
            k
            for k, v in self.items()
            if k not in NON_EDITABLE_FIELDS
            and k not in builtin_fields
            and type(v) in [bool, int, float, str]
        ]

    def editable_fields(self):
        """Return built-in editable keys followed by custom editable keys."""
        return self.builtin_editable_fields() + self.custom_editable_fields()

    def html_field_name(self, field_name):
        """Return the form field name '<config name>.<field_name>'."""
        return '{}.{}'.format(self['name'], field_name)

    def editable_field_html(self, field_name):
        """Return an HTML row with a label and an <input> for field_name,
        pre-filled with its current value.

        The input kind follows the value's type: a checkbox for bool, an
        integer number input for int, a decimal number input for float or
        Decimal, and a text input for str.

        Raises:
            TypeError: if the value has any other type.
        """
        existing_value = self[field_name]
        html_field_name = self.html_field_name(field_name)

        # html_escape (quote=True) also escapes single quotes, which delimit
        # the attribute values below. Browsers decode the entities, so the
        # submitted name/value are unchanged.
        base_attrs = ["name='{}'".format(html_escape(html_field_name))]
        # bool must be tested before int, because bool is a subclass of int.
        if isinstance(existing_value, bool):
            attrs = [
                "type='checkbox'",
                'checked' if existing_value else '',
                # don't use class=form-control because it looks too big,
                # like it's intended for mobile devices
            ]
        elif isinstance(existing_value, int):
            attrs = [
                "type='number'",
                "required",
                "step='1'",
                "value='{}'".format(existing_value),
                "class='form-control'",
            ]
        elif isinstance(existing_value, (float, Decimal)):
            # convert to float, e.g. participation_fee
            attrs = [
                "class='form-control'",
                "type='number'",
                "step='any'",
                "required",
                "value='{}'".format(float(existing_value)),
            ]
        elif isinstance(existing_value, str):
            attrs = [
                "type='text'",
                "value='{}'".format(html_escape(existing_value)),
                "class='form-control'",
            ]
        else:
            # previously this fell through to an UnboundLocalError on `attrs`
            raise TypeError(
                f'Session config field {field_name!r} has unsupported type '
                f'{type(existing_value).__name__}'
            )

        # The template must consume BOTH arguments: the label and the
        # <input> attributes.
        # NOTE(review): confirm this row markup matches the table/form layout
        # of the page that renders these rows.
        row_html = '''
        <tr>
            <td><b>{}</b></td>
            <td><input {}></td>
        </tr>
        '''.format(
            html_escape(field_name), ' '.join(base_attrs + attrs)
        )
        return row_html

    def builtin_editable_fields_html(self):
        """Return one HTML row per built-in editable field."""
        return [self.editable_field_html(k) for k in self.builtin_editable_fields()]

    def custom_editable_fields_html(self):
        """Return one HTML row per custom editable field."""
        return [self.editable_field_html(k) for k in self.custom_editable_fields()]


def get_session_configs_dict(
    SESSION_CONFIGS: List[Dict], SESSION_CONFIG_DEFAULTS: Dict
):
    """Build a {name: SessionConfig} mapping.

    Each entry of SESSION_CONFIGS is layered over SESSION_CONFIG_DEFAULTS
    and validated with SessionConfig.clean().

    Raises:
        SessionConfigError: on an invalid config or a duplicate name.
    """
    SESSION_CONFIGS_DICT = {}
    for config_dict in SESSION_CONFIGS:
        config_obj = SessionConfig(SESSION_CONFIG_DEFAULTS)
        config_obj.update(config_dict)
        config_obj.clean()
        # clean() guarantees 'name' is present on the merged config
        config_name = config_obj['name']
        if config_name in SESSION_CONFIGS_DICT:
            msg = f"Duplicate SESSION_CONFIG name: {config_name}"
            raise SessionConfigError(msg)
        SESSION_CONFIGS_DICT[config_name] = config_obj
    return SESSION_CONFIGS_DICT


SESSION_CONFIGS_DICT = get_session_configs_dict(
    settings.SESSION_CONFIGS, settings.SESSION_CONFIG_DEFAULTS
)


class CreateSessionInvalidArgs(ValueError):
    """Raised by create_session when its arguments are invalid."""


def create_session(
    session_config_name,
    *,
    num_participants,
    label='',
    room_name=None,
    is_mturk=False,
    is_demo=False,
    modified_session_config_fields=None,
) -> Session:
    """Create and persist a Session and all of its per-app objects.

    Steps:
    - Copy the named config and apply modified_session_config_fields.
    - Create the participants.
    - For every app in app_sequence, create its subsessions (one per round),
      groups and players.
    - Store the total page count on each participant.
    - Call each subsession's ``creating_session`` hook, if defined.
    - Optionally attach the session to a room.

    Args:
        num_participants: None means max(get_lcm(), 2); otherwise it must be
            a multiple of the config's get_lcm().

    Raises:
        CreateSessionInvalidArgs: unknown config name, attempt to modify a
            NON_EDITABLE_FIELDS key, or a participant count that does not
            divide evenly into groups.
        SessionConfigError: if the modified config fails validation.

    If anything fails after the Session row is committed, the session is
    deleted and the exception is re-raised.
    """
    try:
        session_config = SESSION_CONFIGS_DICT[session_config_name]
    except KeyError:
        msg = 'Session config "{}" not found in settings.SESSION_CONFIGS.'
        raise CreateSessionInvalidArgs(msg.format(session_config_name)) from None
    else:
        # copy so that we don't mutate the original
        # .copy() returns a dict, so need to convert back to SessionConfig
        session_config = SessionConfig(session_config.copy())

    modified_config = modified_session_config_fields or {}
    # this is for API. don't want to mislead people
    # to put stuff in the session config that should be in the session.
    bad_keys = modified_config.keys() & NON_EDITABLE_FIELDS
    if bad_keys:
        raise CreateSessionInvalidArgs(
            f'The following session config fields are not editable: {bad_keys}'
        )
    session_config.update(modified_config)

    # check validity and converts serialized decimal & currency values
    # back to their original data type (because they were serialized
    # when passed through channels
    session_config.clean()

    # check that it divides evenly
    session_lcm = session_config.get_lcm()
    if num_participants is None:
        # most games are multiplayer, so if it's under 2, we bump it to 2
        num_participants = max(session_lcm, 2)
    elif num_participants % session_lcm:
        msg = (
            'Session Config {}: Number of participants ({}) is not a multiple '
            'of group size ({})'
        ).format(session_config['name'], num_participants, session_lcm)
        raise CreateSessionInvalidArgs(msg)

    session = Session(
        config=session_config,
        label=label,
        is_demo=is_demo,
        num_participants=num_participants,
        is_mturk=is_mturk,
    )
    db.add(session)
    # i think the .commit() is necessary for the object to have a PK, so that FKs can work,
    # etc.
    db.commit()

    try:
        session_code = session.code
        participants = [
            Participant(
                id_in_session=id_in_session,
                session=session,
                _session_code=session_code,
            )
            for id_in_session in range(1, num_participants + 1)
        ]
        db.add_all(participants)
        db.commit()

        # (id, code) rows ordered by id. Players are assigned to participants
        # by position in this list.
        participant_values = (
            db.query(Participant)
            .join(Session)
            .filter(Session.id == session.id)
            .order_by(Participant.id)
            .with_entities(Participant.id, Participant.code)
        ).all()

        num_pages = 0
        for app_name in session_config['app_sequence']:
            views_module = common.get_pages_module(app_name)
            models_module = get_models_module(app_name)
            num_rounds = get_builtin_constant(app_name, 'num_rounds')
            round_numbers = list(range(1, num_rounds + 1))

            num_pages += num_rounds * len(views_module.page_sequence)

            Subsession = models_module.Subsession
            Group = models_module.Group
            Player = models_module.Player
            Constants = get_constants(app_name)

            subsessions = [
                Subsession(round_number=round_number, session=session)
                for round_number in round_numbers
            ]
            db.add_all(subsessions)
            db.commit()

            # Materialize once: these (id, round_number) rows are iterated
            # twice below (for groups and for players).
            subsessions = (
                dbq(Subsession)
                .filter_by(session=session)
                .order_by('round_number')
                .with_entities('id', 'round_number')
            ).all()

            ppg = Constants.get_normalized('players_per_group')
            if ppg is None or Subsession._has_group_by_arrival_time():
                ppg = num_participants

            # exact integer division (previously int(float division))
            num_groups_per_round = num_participants // ppg

            groups_to_create = []
            for ss_id, ss_rd in subsessions:
                for id_in_subsession in range(1, num_groups_per_round + 1):
                    groups_to_create.append(
                        Group(
                            session=session,
                            subsession_id=ss_id,
                            round_number=ss_rd,
                            id_in_subsession=id_in_subsession,
                        )
                    )

            db.add_all(groups_to_create)

            groups = (
                dbq(Group).filter_by(session=session).order_by('id_in_subsession')
            ).all()

            # subsession_id -> group ids, ordered by id_in_subsession
            groups_lookup = defaultdict(list)
            for group in groups:
                groups_lookup[group.subsession_id].append(group.id)

            players_to_create = []
            for ss_id, ss_rd in subsessions:
                roles = get_roles(Constants)
                # each round fills its groups with participants in the same
                # order, ppg participants per group
                participant_index = 0
                for group_id in groups_lookup[ss_id]:
                    for id_in_group in range(1, ppg + 1):
                        participant = participant_values[participant_index]
                        players_to_create.append(
                            Player(
                                session=session,
                                subsession_id=ss_id,
                                round_number=ss_rd,
                                participant_id=participant[0],
                                group_id=group_id,
                                id_in_group=id_in_group,
                                _role=get_role(roles, id_in_group),
                            )
                        )
                        participant_index += 1

            # Create players
            db.add_all(players_to_create)

        dbq(Participant).filter_by(session=session).update(
            {Participant._max_page_index: num_pages}
        )

        for subsession in session.get_subsessions():
            target = subsession.get_user_defined_target()
            func = getattr(target, 'creating_session', None)
            if func:
                func(subsession)

        session._set_admin_report_app_names()

        if room_name is not None:
            from otree.room import ROOM_DICT

            room = ROOM_DICT[room_name]
            room.set_session(session)

        db.commit()

        return session
    except Exception:
        # another way would be to look into nested transactions,
        # but this seems simpler.
        db.delete(session)
        raise


class CreateSessionError(Exception):
    pass


def create_session_traceback_wrapper(**kwargs):
    '''
    catch it at an inner level,
    so we can give smaller tracebacks on 'creating session' page
    '''
    try:
        return create_session(**kwargs)
    except Exception as exc:
        raise CreateSessionError from exc