import asyncio
import inspect
import logging
import time
from typing import Dict, Any

import otree.common
from otree.channels import utils as channel_utils
from otree.models import Participant, BasePlayer, BaseGroup
from otree.lookup import get_page_lookup
from otree.database import NoResultFound

logger = logging.getLogger(__name__)


class LiveMethodBadReturnValue(Exception):
    """Raised when a user's live_method is the wrong kind of callable or
    produces a value that cannot be routed to players."""


ASYNC_OR_YIELD_ISSUE_MSG = "live_method can only be (a) a regular function that returns a value, or (b) an async generator function."


def _get_live_page_entities(participant_code):
    """Resolve the objects needed to handle a live message for a participant.

    Returns:
        A tuple ``(app_name, lookup, models_module, PageClass, participant)``,
        or ``None`` (after logging a warning) if no participant has this code.
    """
    try:
        participant = Participant.objects_get(code=participant_code)
    except NoResultFound:
        logger.warning(f'Participant not found: {participant_code}')
        return None
    lookup = get_page_lookup(participant._session_code, participant._index_in_pages)
    app_name = lookup.app_name
    models_module = otree.common.get_models_module(app_name)
    PageClass = lookup.page_class
    return app_name, lookup, models_module, PageClass, participant


def live_method_is_async(participant_code):
    """Tell whether the participant's current page has an async-generator live_method.

    Returns ``False`` when the participant cannot be found; previously this
    case crashed with a TypeError while unpacking ``None``.
    """
    entities = _get_live_page_entities(participant_code)
    if entities is None:
        return False
    _, _, _, PageClass, _ = entities
    return inspect.isasyncgenfunction(PageClass.live_method)


def _validate_live_retval(retval, pcodes_dict):
    """Raise LiveMethodBadReturnValue unless retval is a routable dict.

    A routable dict is keyed either by the single broadcast key ``0`` or by
    id_in_group values that exist in ``pcodes_dict``.
    """
    if not isinstance(retval, dict):
        raise LiveMethodBadReturnValue('live method must return a dict')
    if 0 in retval:
        if len(retval) > 1:
            raise LiveMethodBadReturnValue(
                'If dict returned by live_method has key 0, it must not contain any other keys'
            )
        return
    for pid in retval:
        if pid not in pcodes_dict:
            msg = f'live_method has invalid return value. No player with id_in_group={repr(pid)}'
            raise LiveMethodBadReturnValue(msg)


def _build_pcode_retval(retval, pcodes_dict):
    """Map each recipient's participant code to the message envelope to send.

    A player receives the value stored under their own id_in_group, falling
    back to the broadcast key ``0``; players with no value are left out.
    """
    pcode_retval = {}
    for pid, pcode in pcodes_dict.items():
        outgoing = retval.get(pid, retval.get(0))
        if outgoing is not None:
            pcode_retval[pcode] = {
                'otree_success': True,
                'live_method_payload': outgoing,
            }
    return pcode_retval


async def live_payload_function(participant_code, page_name, payload):
    """Run the current page's live_method for one message and send back results.

    Args:
        participant_code: code of the participant who sent the message.
        page_name: name of the page class the browser sent the message from.
        payload: the data passed to live_method.

    Raises:
        LiveMethodBadReturnValue: if live_method is an unsupported kind of
            callable or returns/yields something that is not a routable dict.
    """
    entities = _get_live_page_entities(participant_code)
    if entities is None:
        # The warning was already logged by the helper; nothing to dispatch.
        return
    app_name, lookup, models_module, PageClass, participant = entities

    # this could be incorrect if the player advances right after liveSend is executed.
    # maybe just return if it doesn't match. (but leave it in for now and see how
    # much that occurs, don't want silent failures.)
    if page_name != PageClass.__name__:
        logger.warning(
            f'Ignoring liveSend message from {participant_code} because '
            f'they are on page {PageClass.__name__}, not {page_name}.'
        )
        return

    player = models_module.Player.objects_get(
        round_number=lookup.round_number, participant=participant
    )

    # The group is needed to find every participant who may receive a reply.
    group = player.group
    live_method = PageClass.live_method

    Player: BasePlayer = models_module.Player
    # id_in_group -> participant code, for every player in the sender's group.
    pcodes_dict = {
        id_in_group: pcode
        for id_in_group, pcode in Player.objects_filter(group=group)
        .join(Participant)
        .with_entities(
            Player.id_in_group,
            Participant.code,
        )
    }

    async for retval in call_live_method_compat(live_method, player, payload):
        # we should require a return, otherwise the user might forget it.
        # but this breaks backward compat.
        # NOTE(review): a falsy value ends processing entirely, so an async
        # generator that yields something falsy is not consumed any further.
        if not retval:
            return
        _validate_live_retval(retval, pcodes_dict)
        pcode_retval = _build_pcode_retval(retval, pcodes_dict)
        await _live_send_back(
            participant._session_code, participant._index_in_pages, pcode_retval
        )


async def _live_send_back(session_code, page_index, pcode_retval):
    '''separate function for easier patching

    Sends each participant's message concurrently. Delivery stays best-effort
    (one failure does not prevent the others), but failures are now logged
    instead of being discarded silently.
    '''
    pcodes = list(pcode_retval)
    tasks = [
        channel_utils.group_send(
            group=channel_utils.live_group(session_code, page_index, pcode),
            data=pcode_retval[pcode],
        )
        for pcode in pcodes
    ]
    if not tasks:
        return
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for pcode, result in zip(pcodes, results):
        if isinstance(result, BaseException):
            logger.error(
                f'Failed to send live message to participant {pcode}',
                exc_info=result,
            )


async def call_live_method_compat(live_method, player, payload):
    """
    before the "compat" referred to string vs method.
    now it refers to yield/return style

    Async generator yielding each value produced by ``live_method``:
    every item of an async generator function, or the single return value
    of a regular function.

    Raises:
        LiveMethodBadReturnValue: if live_method is a plain coroutine function
            (async def without yield) or a regular function that returns a
            generator.
    """
    # locally the server is not shutting down properly when i do Ctrl+C,
    # but that was a pre-existing issue with my local dev setup, apparently.
    if inspect.isasyncgenfunction(live_method):
        # Case 1: async generator function (async def with yield)
        try:
            async for item in live_method(player, payload):
                yield item
        except asyncio.CancelledError:
            # Handle graceful cancellation when server shuts down
            # this is needed to allow Ctrl+C to stop the current tasks.
            return
    elif inspect.iscoroutinefunction(live_method):
        # Async function (async def without yield) - not allowed
        raise LiveMethodBadReturnValue(ASYNC_OR_YIELD_ISSUE_MSG)
    else:
        # Case 2: regular function
        result = live_method(player, payload)
        if inspect.isgenerator(result):
            raise LiveMethodBadReturnValue(ASYNC_OR_YIELD_ISSUE_MSG)
        yield result


# Seconds a participant's lock may go unrequested before it can be discarded.
_USER_LOCK_MAX_IDLE_TIME = 3600


class ParticipantLockManager:
    """Hands out one asyncio.Lock per participant code and prunes idle ones."""

    def __init__(self):
        # participant code -> lock
        self._locks: Dict[str, asyncio.Lock] = {}
        # participant code -> time.time() of the most recent get_lock() call
        self._last_used: Dict[str, float] = {}

    def get_lock(self, pcode: str) -> asyncio.Lock:
        """Return the lock for ``pcode``, creating it on first use."""
        now = time.time()

        # Clean on every request (very lightweight)
        self._cleanup_expired(now)

        # Get or create lock
        if pcode not in self._locks:
            self._locks[pcode] = asyncio.Lock()

        self._last_used[pcode] = now
        return self._locks[pcode]

    def _cleanup_expired(self, now: float):
        """Drop locks idle for longer than the limit, unless currently held."""
        expired = [
            pcode
            for pcode, last_used in self._last_used.items()
            if (
                now - last_used > _USER_LOCK_MAX_IDLE_TIME
                and not self._locks[pcode].locked()
            )
        ]
        for pcode in expired:
            del self._locks[pcode]
            del self._last_used[pcode]


_lock_manager = ParticipantLockManager()


def get_participant_scoped_lock(pcode: str) -> asyncio.Lock:
    """Return the module-wide lock associated with participant ``pcode``."""
    return _lock_manager.get_lock(pcode)