from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware import Middleware from starlette.middleware.sessions import SessionMiddleware import time from starlette.requests import Request import logging from otree.database import db, NEW_IDMAP_EACH_REQUEST from otree.common import _SECRET, lock import asyncio import threading from contextlib import nullcontext logger = logging.getLogger('otree.perf') lock2 = asyncio.Lock() # URL paths that should NOT use the global lock # Populated from routes after they're built to avoid duplicating URL patterns PATHS_WITHOUT_LOCK = None def _get_paths_without_lock(): """Lazy load paths from routes to avoid circular import.""" global PATHS_WITHOUT_LOCK if PATHS_WITHOUT_LOCK is None: from otree.urls import get_paths_without_lock PATHS_WITHOUT_LOCK = get_paths_without_lock() return PATHS_WITHOUT_LOCK class CommitTransactionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): path = request.url.path # Check if this path should opt out of the global lock use_lock = not any(path.startswith(prefix) for prefix in _get_paths_without_lock()) lock_context = lock2 if use_lock else nullcontext() async with lock_context: if NEW_IDMAP_EACH_REQUEST: db.new_session() response = await call_next(request) if response.status_code < 500: db.commit() else: db.rollback() return response class PerfMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): start = time.time() response = await call_next(request) # heroku has 'X-Request-ID' request_id = request.headers.get('X-Request-ID') if request_id: # only log this info on Heroku elapsed = time.time() - start msec = int(elapsed * 1000) msg = f'own_time={msec}ms request_id={request_id}' logger.info(msg) return response