From d372ab3280c0bdb54795a3e9cc7a6470f5d3218c Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 8 Jan 2026 19:21:24 +0100 Subject: [PATCH] Add cancel_task API to the task scheduler (#19310) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/19310.misc | 1 + docs/admin_api/scheduled_tasks.md | 3 +- synapse/replication/tcp/commands.py | 13 +++ synapse/replication/tcp/handler.py | 14 ++- synapse/rest/admin/scheduled_tasks.py | 7 +- synapse/types/__init__.py | 2 + synapse/util/task_scheduler.py | 41 ++++++++- tests/rest/admin/test_scheduled_tasks.py | 2 +- tests/util/test_task_scheduler.py | 105 ++++++++++++++++++++++- 9 files changed, 178 insertions(+), 10 deletions(-) create mode 100644 changelog.d/19310.misc diff --git a/changelog.d/19310.misc b/changelog.d/19310.misc new file mode 100644 index 0000000000..5080d7d985 --- /dev/null +++ b/changelog.d/19310.misc @@ -0,0 +1 @@ +Add an internal `cancel_task` API to the task scheduler. diff --git a/docs/admin_api/scheduled_tasks.md b/docs/admin_api/scheduled_tasks.md index b80da5083c..949a03ee39 100644 --- a/docs/admin_api/scheduled_tasks.md +++ b/docs/admin_api/scheduled_tasks.md @@ -36,9 +36,10 @@ It returns a JSON body like the following: - "scheduled" - Task is scheduled but not active - "active" - Task is active and probably running, and if not will be run on next scheduler loop run - "complete" - Task has completed successfully + - "cancelled" - Task has been cancelled - "failed" - Task is over and either returned a failed status, or had an exception -* `max_timestamp`: int - Is optional. Returns only the scheduled tasks with a timestamp inferior to the specified one. +* `max_timestamp`: int - Is optional. Returns only the scheduled tasks with a timestamp (in milliseconds since the unix epoch) inferior to the specified one. **Response** diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 79194f7275..0c85fb36fc 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -505,6 +505,18 @@ class NewActiveTaskCommand(_SimpleCommand): NAME = "NEW_ACTIVE_TASK" +class CancelTaskCommand(_SimpleCommand): + """Sent to inform the instance handling background tasks that a task + has been cancelled and should be terminated. + + Format:: + + CANCEL_TASK "" + """ + + NAME = "CANCEL_TASK" + + _COMMANDS: tuple[type[Command], ...] = ( ServerCommand, RdataCommand, @@ -520,6 +532,7 @@ _COMMANDS: tuple[type[Command], ...] = ( ClearUserSyncsCommand, LockReleasedCommand, NewActiveTaskCommand, + CancelTaskCommand, ) # Map of command name to command type. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 05370045e6..087c87545e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -35,6 +35,7 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.replication.tcp.commands import ( + CancelTaskCommand, ClearUserSyncsCommand, Command, FederationAckCommand, @@ -746,10 +747,17 @@ class ReplicationCommandHandler: def on_NEW_ACTIVE_TASK( self, conn: IReplicationConnection, cmd: NewActiveTaskCommand ) -> None: - """Called when get a new NEW_ACTIVE_TASK command.""" + """Called when we get a new NEW_ACTIVE_TASK command.""" if self._task_scheduler: self._task_scheduler.on_new_task(cmd.data) + async def on_CANCEL_TASK( + self, conn: IReplicationConnection, cmd: CancelTaskCommand + ) -> None: + """Called when we get a new CANCEL_TASK command.""" + if self._task_scheduler: + await self._task_scheduler.on_cancel_task(cmd.data) + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -872,6 +880,10 @@ class ReplicationCommandHandler: """Called when a new task has been scheduled for immediate launch and is ACTIVE.""" self.send_command(NewActiveTaskCommand(task_id)) + def send_cancel_task(self, task_id: str) -> None: + """Called when a scheduled task has been cancelled and should be terminated.""" + self.send_command(CancelTaskCommand(task_id)) + UpdateToken = TypeVar("UpdateToken") UpdateRow = TypeVar("UpdateRow") diff --git a/synapse/rest/admin/scheduled_tasks.py b/synapse/rest/admin/scheduled_tasks.py index 41c402b424..5b3526c7e5 100644 --- a/synapse/rest/admin/scheduled_tasks.py +++ b/synapse/rest/admin/scheduled_tasks.py @@ -41,7 +41,12 @@ class ScheduledTasksRestServlet(RestServlet): # extract query params action_name = parse_string(request, "action_name") resource_id = parse_string(request, "resource_id") - status = parse_string(request, "job_status") + status = parse_string(request, "status") + # This parameter was historically called `job_status`, while the Admin API docs + # defined it as `status`. We now support both, as `status` is generally + # a nicer name. A v2 of this endpoint should keep only `status`. + if status is None: + status = parse_string(request, "job_status") max_timestamp = parse_integer(request, "max_timestamp") actions = [action_name] if action_name else None diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 16892b37c0..99eefb8acb 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1546,6 +1546,8 @@ class TaskStatus(str, Enum): COMPLETE = "complete" # Task is over and either returned a failed status, or had an exception FAILED = "failed" + # Task has been cancelled + CANCELLED = "cancelled" @attr.s(auto_attribs=True, frozen=True, slots=True) diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 353ddb70bc..e5cfc85a37 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -22,6 +22,7 @@ import logging from typing import TYPE_CHECKING, Awaitable, Callable +from twisted.internet import defer from twisted.python.failure import Failure from synapse.logging.context import ( @@ -111,7 +112,8 @@ class TaskScheduler: self.server_name = hs.hostname self._store = hs.get_datastores().main self._clock = hs.get_clock() - self._running_tasks: set[str] = set() + # A map between a task's ID and a deferred linked to the task + self._running_tasks: dict[str, defer.Deferred] = {} # A map between action names and their registered function self._actions: dict[ str, @@ -325,6 +327,37 @@ class TaskScheduler: raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") await self._store.delete_scheduled_task(id) + async def cancel_task(self, id: str) -> None: + """Cancel an ACTIVE or SCHEDULED task. + + Args: + id: id of the task to cancel + """ + task = await self.get_task(id) + if not task: + logger.debug("Can't cancel task %s because it doesn't exist in the DB", id) + return + + if not ( + task.status == TaskStatus.ACTIVE or task.status == TaskStatus.SCHEDULED + ): + logger.debug( + "Can't cancel task %s because it is neither ACTIVE nor SCHEDULED", id + ) + return + + if self._run_background_tasks: + await self.on_cancel_task(id) + else: + self.hs.get_replication_command_handler().send_cancel_task(id) + + async def on_cancel_task(self, id: str) -> None: + if id in self._running_tasks: + deferred = self._running_tasks[id] + deferred.cancel() + self._running_tasks.pop(id) + await self.update_task(id, status=TaskStatus.CANCELLED) + def on_new_task(self, task_id: str) -> None: """Handle a notification that a new ready-to-run task has been added to the queue""" # Just run the scheduler @@ -458,7 +491,7 @@ class TaskScheduler: result=result, error=error, ) - self._running_tasks.remove(task.id) + self._running_tasks.pop(task.id) current_time = self._clock.time() usage = log_context.get_resource_usage() @@ -489,6 +522,6 @@ class TaskScheduler: if task.id in self._running_tasks: return - self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) - self.hs.run_as_background_process(f"task-{task.action}", wrapper) + deferred = self.hs.run_as_background_process(f"task-{task.action}", wrapper) + self._running_tasks[task.id] = deferred diff --git a/tests/rest/admin/test_scheduled_tasks.py b/tests/rest/admin/test_scheduled_tasks.py index fb275f6d55..4b7adb6b89 100644 --- a/tests/rest/admin/test_scheduled_tasks.py +++ b/tests/rest/admin/test_scheduled_tasks.py @@ -135,7 +135,7 @@ class ScheduledTasksAdminApiTestCase(unittest.HomeserverTestCase): # filter via job_status channel = self.make_request( "GET", - "/_synapse/admin/v1/scheduled_tasks?job_status=active", + "/_synapse/admin/v1/scheduled_tasks?status=active", content={}, access_token=self.admin_user_tok, ) diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py index 2c8e21b339..94c1d778e6 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py @@ -40,6 +40,12 @@ class TestTaskScheduler(HomeserverTestCase): self.task_scheduler.register_action(self._sleeping_task, "_sleeping_task") self.task_scheduler.register_action(self._raising_task, "_raising_task") self.task_scheduler.register_action(self._resumable_task, "_resumable_task") + self.task_scheduler.register_action( + self._incrementing_active_task, "_incrementing_active_task" + ) + self.task_scheduler.register_action( + self._incrementing_running_task, "_incrementing_running_task" + ) async def _test_task( self, task: ScheduledTask @@ -187,8 +193,8 @@ class TestTaskScheduler(HomeserverTestCase): self.assertEqual(task.status, TaskStatus.ACTIVE) # Simulate a synapse restart by emptying the list of running tasks - self.task_scheduler._running_tasks = set() - self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL.as_secs())) + self.task_scheduler._running_tasks = {} + self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL.as_secs()) task = self.get_success(self.task_scheduler.get_task(task_id)) assert task is not None @@ -196,6 +202,101 @@ class TestTaskScheduler(HomeserverTestCase): assert task.result is not None self.assertTrue(task.result.get("success")) + def _test_cancel_task(self, task_id: str) -> None: + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + assert task.status == TaskStatus.ACTIVE + + assert task.result and "counter" in task.result + current_counter = int(task.result["counter"]) + + self.reactor.advance(1) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + assert task.status == TaskStatus.ACTIVE + + # At this point the task should have run at least one more time, let's check the counter + assert task.result and "counter" in task.result + new_counter = int(task.result["counter"]) + assert new_counter > current_counter + current_counter = new_counter + + # Cancelling active task + self.get_success(self.task_scheduler.cancel_task(task_id)) + + self.reactor.advance(1) + + # Task should be marked as cancelled + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.CANCELLED) + + # Task should be in the running tasks + assert task_id not in self.task_scheduler._running_tasks + + # Counter should not increase anymore and stay the same + assert task.result and "counter" in task.result + new_counter = int(task.result["counter"]) + assert new_counter == current_counter + current_counter = new_counter + + # Let's check one more time to be sure that it is not increasing + self.reactor.advance(100) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + assert task.result and "counter" in task.result + new_counter = int(task.result["counter"]) + assert new_counter == current_counter + + async def _incrementing_running_task( + self, task: ScheduledTask + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: + current_counter = 0 + + while True: + current_counter += 1 + await self.task_scheduler.update_task( + task.id, result={"counter": current_counter} + ) + await self.hs.get_clock().sleep(Duration(microseconds=1)) + + return TaskStatus.COMPLETE, None, None # type: ignore[unreachable] + + def test_cancel_running_task(self) -> None: + """Schedule and then cancel a long running task that increments a counter.""" + + task_id = self.get_success( + self.task_scheduler.schedule_task( + "_incrementing_running_task", + ) + ) + + self._test_cancel_task(task_id) + + async def _incrementing_active_task( + self, task: ScheduledTask + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: + current_counter = 0 + if task.result and "counter" in task.result: + current_counter = int(task.result["counter"]) + + return TaskStatus.ACTIVE, {"counter": current_counter + 1}, None + + def test_cancel_active_task(self) -> None: + """Schedule and then cancel a long active task that increments a counter, + but step by step, by keeping the task ACTIVE even if it is not running + between the steps.""" + + task_id = self.get_success( + self.task_scheduler.schedule_task( + "_incrementing_active_task", + ) + ) + + self._test_cancel_task(task_id) + class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: