element-synapse/synapse/handlers/worker_lock.py
Erik Johnston 1bddd25a85
Port Clock functions to use Duration class (#19229)
This changes the arguments in clock functions to be `Duration` and
converts call sites and constants into `Duration`. There are still some
more functions around that should be converted (e.g.
`timeout_deferred`), but we leave that to another PR.

We also change `.as_secs()` to return a float, as the rounding broke
things subtly. The only reason to keep it (it's the same as
`timedelta.total_seconds()`) is for symmetry with `as_millis()`.

Follows on from https://github.com/element-hq/synapse/pull/19223
2025-12-01 13:55:06 +00:00

371 lines
12 KiB
Python

#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import logging
import random
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
Collection,
)
from weakref import WeakSet
import attr
from twisted.internet import defer
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import start_active_span
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.databases.main.lock import Lock, LockStore
from synapse.util.async_helpers import timeout_deferred
from synapse.util.clock import Clock
from synapse.util.duration import Duration
if TYPE_CHECKING:
from synapse.logging.opentracing import opentracing
from synapse.server import HomeServer
# Module-level logger, named after this module; used below to warn when a
# lock's retry interval grows suspiciously large.
logger = logging.getLogger(__name__)

# This lock is used to avoid creating an event while we are purging the room.
# We take a read lock when creating an event, and a write one when purging a room.
# This is because it is fine to create several events concurrently, since referenced events
# will not disappear under our feet as long as we don't delete the room.
NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"
class WorkerLocksHandler:
    """A class for waiting on taking out locks, rather than using the storage
    functions directly (which don't support awaiting).

    The handler hands out `WaitingLock` / `WaitingMultiLock` async context
    managers and keeps a weak registry of them, keyed by `(lock_name,
    lock_key)`, so that waiters can be woken up when the notifier reports that
    a matching lock has been released.
    """

    def __init__(self, hs: "HomeServer") -> None:
        """Wire the handler up to the homeserver's clock, store and notifier.

        Also starts a 30 second looping call that prunes empty registry
        entries, and registers `_on_lock_released` with the notifier.
        """
        self.hs = hs  # nb must be called this for @wrap_as_background_process
        self._store = hs.get_datastores().main
        self._clock = hs.get_clock()
        self._notifier = hs.get_notifier()
        self._instance_name = hs.get_instance_name()

        # Map from lock name/key to set of `WaitingLock` that are active for
        # that lock. The sets are weak, so a waiter that is no longer
        # referenced elsewhere drops out of its set on its own.
        self._locks: dict[
            tuple[str, str], WeakSet[WaitingLock | WaitingMultiLock]
        ] = {}

        self._clock.looping_call(self._cleanup_locks, Duration(seconds=30))

        self._notifier.add_lock_released_callback(self._on_lock_released)

    def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
        """Acquire a standard lock, returns a context manager that will block
        until the lock is acquired.

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.

        Usage:
            async with handler.acquire_lock(name, key):
                # Do work while holding the lock...
        """

        lock = WaitingLock(
            clock=self._clock,
            store=self._store,
            handler=self,
            lock_name=lock_name,
            lock_key=lock_key,
            # `None` selects the plain (non read/write) lock in `WaitingLock`.
            write=None,
        )

        # Register the waiter so `_on_lock_released` can wake it.
        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def acquire_read_write_lock(
        self,
        lock_name: str,
        lock_key: str,
        *,
        write: bool,
    ) -> "WaitingLock":
        """Acquire a read/write lock, returns a context manager that will block
        until the lock is acquired.

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.

        Usage:
            async with handler.acquire_read_write_lock(name, key, write=True):
                # Do work while holding the lock...
        """

        lock = WaitingLock(
            clock=self._clock,
            store=self._store,
            handler=self,
            lock_name=lock_name,
            lock_key=lock_key,
            write=write,
        )

        # Register the waiter so `_on_lock_released` can wake it.
        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def acquire_multi_read_write_lock(
        self,
        lock_names: Collection[tuple[str, str]],
        *,
        write: bool,
    ) -> "WaitingMultiLock":
        """Acquires multi read/write locks at once, returns a context manager
        that will block until all the locks are acquired.

        This will try and acquire all locks at once, and will never hold on to a
        subset of the locks. (This avoids accidentally creating deadlocks).

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.
        """

        lock = WaitingMultiLock(
            lock_names=lock_names,
            write=write,
            clock=self._clock,
            store=self._store,
            handler=self,
        )

        # The same waiter is registered under every name/key it wants, so a
        # release of any one of them wakes it up to retry the whole set.
        for lock_name, lock_key in lock_names:
            self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
        """Notify that a lock has been released.

        Forwards this instance's name together with the lock name/key to the
        notifier. (Per the original documentation this pokes both the notifier
        and replication; the replication side is not visible from here.)
        """

        self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)

    def _on_lock_released(
        self, instance_name: str, lock_name: str, lock_key: str
    ) -> None:
        """Called when a lock has been released.

        Wakes up any locks that might be waiting on this.

        Args:
            instance_name: The instance reported by the notifier; not used here.
            lock_name: Name of the released lock.
            lock_key: Key of the released lock.
        """
        locks = self._locks.get((lock_name, lock_key))
        if not locks:
            # Nothing registered (or every waiter has already gone away).
            return

        def _wake_all_locks(
            locks: Collection[WaitingLock | WaitingMultiLock],
        ) -> None:
            for lock in locks:
                lock.release_lock()

        # The wake-up is scheduled through the clock with a zero delay rather
        # than being run inline in this callback.
        self._clock.call_later(
            Duration(seconds=0),
            _wake_all_locks,
            locks,
        )

    @wrap_as_background_process("_cleanup_locks")
    async def _cleanup_locks(self) -> None:
        """Periodically cleans out stale entries in the locks map"""
        # An empty `WeakSet` is falsy, so this drops keys with no live waiters.
        self._locks = {key: value for key, value in self._locks.items() if value}
@attr.s(auto_attribs=True, eq=False)
class WaitingLock:
    """Async context manager that keeps retrying until a single lock is taken.

    On entry it repeatedly asks the store for the lock; between attempts it
    waits on `deferred`, which is resolved by `release_lock`, or for a retry
    timeout that backs off between attempts, whichever comes first.
    """

    clock: Clock
    store: LockStore
    handler: WorkerLocksHandler
    lock_name: str
    lock_key: str
    # `None` takes the plain lock; `True`/`False` takes the read/write lock
    # with that `write` flag.
    write: bool | None
    # Resolved by `release_lock` to wake a waiter early; replaced with a fresh
    # deferred on every acquisition attempt.
    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
    # The underlying store lock, once it has been acquired.
    _inner_lock: Lock | None = None
    # Seconds to wait before the next retry (before jitter is applied).
    _retry_interval: float = 0.1
    # Span created at construction time; entered in `__aenter__` and exited in
    # `__aexit__`.
    _lock_span: "opentracing.Scope" = attr.Factory(
        lambda: start_active_span("WaitingLock.lock")
    )

    def release_lock(self) -> None:
        """Wake the waiter (by resolving the deferred) so it retries the lock.

        Does nothing if the current deferred has already been called.
        """
        if not self.deferred.called:
            with PreserveLoggingContext():
                self.deferred.callback(None)

    async def __aenter__(self) -> None:
        """Block until the lock has been acquired, then enter the inner lock."""
        # NOTE(review): if acquisition raises, `async with` will not call
        # `__aexit__`, so the span entered here is not exited on that path -
        # confirm whether that matters for tracing.
        self._lock_span.__enter__()

        with start_active_span("WaitingLock.waiting_for_lock"):
            while self._inner_lock is None:
                # Create the new deferred *before* trying the store, so that a
                # release notification arriving during the attempt is not lost.
                self.deferred = defer.Deferred()

                if self.write is not None:
                    lock = await self.store.try_acquire_read_write_lock(
                        self.lock_name, self.lock_key, write=self.write
                    )
                else:
                    lock = await self.store.try_acquire_lock(
                        self.lock_name, self.lock_key
                    )

                if lock:
                    self._inner_lock = lock
                    break

                try:
                    # Wait until we get notified that the lock might have been
                    # released (by the deferred being resolved). We also
                    # periodically wake up in case the lock was released but we
                    # weren't notified.
                    with PreserveLoggingContext():
                        timeout = self._get_next_retry_interval()
                        await timeout_deferred(
                            deferred=self.deferred,
                            timeout=timeout,
                            clock=self.clock,
                        )
                except Exception:
                    # Deliberately best-effort: whatever interrupted the wait
                    # (presumably most often the timeout firing), just loop
                    # round and try to take the lock again.
                    pass

        return await self._inner_lock.__aenter__()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool | None:
        """Notify waiters, exit the inner lock and close the tracing span.

        Returns:
            Whatever the inner lock's `__aexit__` returns.
        """
        assert self._inner_lock

        # NOTE(review): the notification is sent before the inner lock's
        # `__aexit__` has run; waiters that lose that race fall back to their
        # retry timeout.
        self.handler.notify_lock_released(self.lock_name, self.lock_key)

        try:
            r = await self._inner_lock.__aexit__(exc_type, exc, tb)
        finally:
            # Always close the span, even if exiting the inner lock raised.
            self._lock_span.__exit__(exc_type, exc, tb)

        return r

    def _get_next_retry_interval(self) -> float:
        """Return how long to wait before the next attempt and advance the back-off.

        The stored interval starts at 0.1s, then becomes at least 5s and
        doubles on each call. It is capped at 15 minutes so that a
        long-contended lock keeps being re-polled at a sane rate instead of
        the interval growing without bound.

        Returns:
            The current interval in seconds with +/-10% random jitter applied.
        """
        current_interval = self._retry_interval

        # Kept above the 10 minute warning threshold below so that the
        # "excessive" warning still fires once the cap has been reached.
        max_interval = Duration(minutes=15).as_secs()
        self._retry_interval = min(max(5, current_interval * 2), max_interval)

        if self._retry_interval > Duration(minutes=10).as_secs():  # >7 iterations
            logger.warning(
                "Lock timeout is getting excessive: %ss. There may be a deadlock.",
                self._retry_interval,
            )
        return current_interval * random.uniform(0.9, 1.1)
@attr.s(auto_attribs=True, eq=False)
class WaitingMultiLock:
    """Async context manager that keeps retrying until a set of read/write
    locks has been taken in one go.

    On entry it repeatedly asks the store for all of `lock_names` at once;
    between attempts it waits on `deferred`, which is resolved by
    `release_lock`, or for a retry timeout that backs off between attempts,
    whichever comes first.
    """

    # The `(lock_name, lock_key)` pairs to acquire together.
    lock_names: Collection[tuple[str, str]]

    write: bool

    clock: Clock
    store: LockStore
    handler: WorkerLocksHandler

    # Resolved by `release_lock` to wake a waiter early; replaced with a fresh
    # deferred on every acquisition attempt.
    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)

    # The context manager returned by the store once all locks were acquired.
    _inner_lock_cm: AsyncContextManager | None = None
    # Seconds to wait before the next retry (before jitter is applied).
    _retry_interval: float = 0.1
    # Span created at construction time; entered in `__aenter__` and exited in
    # `__aexit__`.
    _lock_span: "opentracing.Scope" = attr.Factory(
        lambda: start_active_span("WaitingLock.lock")
    )

    def release_lock(self) -> None:
        """Wake the waiter (by resolving the deferred) so it retries the locks.

        Does nothing if the current deferred has already been called.
        """
        if not self.deferred.called:
            with PreserveLoggingContext():
                self.deferred.callback(None)

    async def __aenter__(self) -> None:
        """Block until all locks have been acquired, then enter the inner
        context manager."""
        # NOTE(review): if acquisition raises, `async with` will not call
        # `__aexit__`, so the span entered here is not exited on that path -
        # confirm whether that matters for tracing.
        self._lock_span.__enter__()

        with start_active_span("WaitingLock.waiting_for_lock"):
            while self._inner_lock_cm is None:
                # Create the new deferred *before* trying the store, so that a
                # release notification arriving during the attempt is not lost.
                self.deferred = defer.Deferred()

                lock_cm = await self.store.try_acquire_multi_read_write_lock(
                    self.lock_names, write=self.write
                )

                if lock_cm:
                    self._inner_lock_cm = lock_cm
                    break

                try:
                    # Wait until we get notified that the lock might have been
                    # released (by the deferred being resolved). We also
                    # periodically wake up in case the lock was released but we
                    # weren't notified.
                    with PreserveLoggingContext():
                        timeout = self._get_next_retry_interval()
                        await timeout_deferred(
                            deferred=self.deferred,
                            timeout=timeout,
                            clock=self.clock,
                        )
                except Exception:
                    # Deliberately best-effort: whatever interrupted the wait
                    # (presumably most often the timeout firing), just loop
                    # round and try to take the locks again.
                    pass

        assert self._inner_lock_cm
        await self._inner_lock_cm.__aenter__()
        return

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool | None:
        """Notify waiters for every lock, exit the inner context manager and
        close the tracing span.

        Returns:
            Whatever the inner context manager's `__aexit__` returns.
        """
        assert self._inner_lock_cm

        # NOTE(review): the notifications are sent before the inner context
        # manager's `__aexit__` has run; waiters that lose that race fall back
        # to their retry timeout.
        for lock_name, lock_key in self.lock_names:
            self.handler.notify_lock_released(lock_name, lock_key)

        try:
            r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
        finally:
            # Always close the span, even if exiting the inner locks raised.
            self._lock_span.__exit__(exc_type, exc, tb)

        return r

    def _get_next_retry_interval(self) -> float:
        """Return how long to wait before the next attempt and advance the back-off.

        The stored interval starts at 0.1s, then becomes at least 5s and
        doubles on each call. It is capped at 15 minutes so that long-contended
        locks keep being re-polled at a sane rate instead of the interval
        growing without bound.

        Returns:
            The current interval in seconds with +/-10% random jitter applied.
        """
        current_interval = self._retry_interval

        # Kept above the 10 minute warning threshold below so that the
        # "excessive" warning still fires once the cap has been reached.
        max_interval = Duration(minutes=15).as_secs()
        self._retry_interval = min(max(5, current_interval * 2), max_interval)

        if self._retry_interval > Duration(minutes=10).as_secs():  # >7 iterations
            logger.warning(
                "Lock timeout is getting excessive: %ss. There may be a deadlock.",
                self._retry_interval,
            )
        return current_interval * random.uniform(0.9, 1.1)