Write union types as X | Y where possible (#19111)
Some checks are pending
Build docker images / Build and push image for linux/amd64 (push) Waiting to run
Build docker images / Build and push image for linux/arm64 (push) Waiting to run
Build docker images / Push merged images to docker.io/matrixdotorg/synapse (push) Blocked by required conditions
Build docker images / Push merged images to ghcr.io/element-hq/synapse (push) Blocked by required conditions
Deploy the documentation / Calculate variables for GitHub Pages deployment (push) Waiting to run
Deploy the documentation / GitHub Pages (push) Blocked by required conditions
Build release artifacts / Calculate list of debian distros (push) Waiting to run
Build release artifacts / Build .deb packages (push) Blocked by required conditions
Build release artifacts / Build wheels on macos-14 (push) Waiting to run
Build release artifacts / Build wheels on macos-15-intel (push) Waiting to run
Build release artifacts / Build wheels on ubuntu-24.04 (push) Waiting to run
Build release artifacts / Build wheels on ubuntu-24.04-arm (push) Waiting to run
Build release artifacts / Build sdist (push) Waiting to run
Build release artifacts / Attach assets to release (push) Blocked by required conditions
Schema / Ensure Synapse config schema is valid (push) Waiting to run
Schema / Ensure generated documentation is up-to-date (push) Waiting to run
Tests / lint (push) Blocked by required conditions
Tests / lint-readme (push) Blocked by required conditions
Tests / linting-done (push) Blocked by required conditions
Tests / calculate-test-jobs (push) Blocked by required conditions
Tests / changes (push) Waiting to run
Tests / check-sampleconfig (push) Blocked by required conditions
Tests / check-schema-delta (push) Blocked by required conditions
Tests / check-lockfile (push) Waiting to run
Tests / Typechecking (push) Blocked by required conditions
Tests / lint-crlf (push) Waiting to run
Tests / lint-newsfile (push) Waiting to run
Tests / lint-clippy (push) Blocked by required conditions
Tests / lint-clippy-nightly (push) Blocked by required conditions
Tests / lint-rust (push) Blocked by required conditions
Tests / lint-rustfmt (push) Blocked by required conditions
Tests / trial (push) Blocked by required conditions
Tests / trial-olddeps (push) Blocked by required conditions
Tests / trial-pypy (all, pypy-3.10) (push) Blocked by required conditions
Tests / sytest (push) Blocked by required conditions
Tests / export-data (push) Blocked by required conditions
Tests / portdb (13, 3.10) (push) Blocked by required conditions
Tests / portdb (17, 3.14) (push) Blocked by required conditions
Tests / complement (monolith, Postgres) (push) Blocked by required conditions
Tests / complement (monolith, SQLite) (push) Blocked by required conditions
Tests / complement (workers, Postgres) (push) Blocked by required conditions
Tests / cargo-test (push) Blocked by required conditions
Tests / cargo-bench (push) Blocked by required conditions
Tests / tests-done (push) Blocked by required conditions

aka PEP 604, added in Python 3.10
This commit is contained in:
Andrew Ferrazzutti 2025-11-06 15:02:33 -05:00 committed by GitHub
parent 6790312831
commit fcac7e0282
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
465 changed files with 4034 additions and 4555 deletions

View file

@ -25,7 +25,6 @@
import argparse
import os
import subprocess
from typing import Optional
from zipfile import ZipFile
from packaging.tags import Tag
@ -80,7 +79,7 @@ def cpython(wheel_file: str, name: str, version: Version, tag: Tag) -> str:
return new_wheel_file
def main(wheel_file: str, dest_dir: str, archs: Optional[str]) -> None:
def main(wheel_file: str, dest_dir: str, archs: str | None) -> None:
"""Entry point"""
# Parse the wheel file name into its parts. Note that `parse_wheel_filename`

1
changelog.d/19111.misc Normal file
View file

@ -0,0 +1 @@
Write union types as `X | Y` where possible, as per PEP 604, added in Python 3.10.

View file

@ -33,7 +33,6 @@ import sys
import time
import urllib
from http import TwistedHttpClient
from typing import Optional
import urlparse
from signedjson.key import NACL_ED25519, decode_verify_key_bytes
@ -726,7 +725,7 @@ class SynapseCmd(cmd.Cmd):
method,
path,
data=None,
query_params: Optional[dict] = None,
query_params: dict | None = None,
alt_text=None,
):
"""Runs an HTTP request and pretty prints the output.

View file

@ -22,7 +22,6 @@
import json
import urllib
from pprint import pformat
from typing import Optional
from twisted.internet import defer, reactor
from twisted.web.client import Agent, readBody
@ -90,7 +89,7 @@ class TwistedHttpClient(HttpClient):
body = yield readBody(response)
return json.loads(body)
def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None):
def _create_put_request(self, url, json_data, headers_dict: dict | None = None):
"""Wrapper of _create_request to issue a PUT request"""
headers_dict = headers_dict or {}
@ -101,7 +100,7 @@ class TwistedHttpClient(HttpClient):
"PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict
)
def _create_get_request(self, url, headers_dict: Optional[dict] = None):
def _create_get_request(self, url, headers_dict: dict | None = None):
"""Wrapper of _create_request to issue a GET request"""
return self._create_request("GET", url, headers_dict=headers_dict or {})
@ -113,7 +112,7 @@ class TwistedHttpClient(HttpClient):
data=None,
qparams=None,
jsonreq=True,
headers: Optional[dict] = None,
headers: dict | None = None,
):
headers = headers or {}
@ -138,7 +137,7 @@ class TwistedHttpClient(HttpClient):
@defer.inlineCallbacks
def _create_request(
self, method, url, producer=None, headers_dict: Optional[dict] = None
self, method, url, producer=None, headers_dict: dict | None = None
):
"""Creates and sends a request to the given url"""
headers_dict = headers_dict or {}

View file

@ -68,7 +68,6 @@ from typing import (
Mapping,
MutableMapping,
NoReturn,
Optional,
SupportsIndex,
)
@ -468,7 +467,7 @@ def add_worker_roles_to_shared_config(
def merge_worker_template_configs(
existing_dict: Optional[dict[str, Any]],
existing_dict: dict[str, Any] | None,
to_be_merged_dict: dict[str, Any],
) -> dict[str, Any]:
"""When given an existing dict of worker template configuration consisting with both
@ -1026,7 +1025,7 @@ def generate_worker_log_config(
Returns: the path to the generated file
"""
# Check whether we should write worker logs to disk, in addition to the console
extra_log_template_args: dict[str, Optional[str]] = {}
extra_log_template_args: dict[str, str | None] = {}
if environ.get("SYNAPSE_WORKERS_WRITE_LOGS_TO_DISK"):
extra_log_template_args["LOG_FILE_PATH"] = f"{data_dir}/logs/{worker_name}.log"

View file

@ -6,7 +6,7 @@ import os
import platform
import subprocess
import sys
from typing import Any, Mapping, MutableMapping, NoReturn, Optional
from typing import Any, Mapping, MutableMapping, NoReturn
import jinja2
@ -50,7 +50,7 @@ def generate_config_from_template(
config_dir: str,
config_path: str,
os_environ: Mapping[str, str],
ownership: Optional[str],
ownership: str | None,
) -> None:
"""Generate a homeserver.yaml from environment variables
@ -147,7 +147,7 @@ def generate_config_from_template(
subprocess.run(args, check=True)
def run_generate_config(environ: Mapping[str, str], ownership: Optional[str]) -> None:
def run_generate_config(environ: Mapping[str, str], ownership: str | None) -> None:
"""Run synapse with a --generate-config param to generate a template config file
Args:

View file

@ -299,7 +299,7 @@ logcontext is not finished before the `async` processing completes.
**Bad**:
```python
cache: Optional[ObservableDeferred[None]] = None
cache: ObservableDeferred[None] | None = None
async def do_something_else(
to_resolve: Deferred[None]
@ -326,7 +326,7 @@ with LoggingContext("request-1"):
**Good**:
```python
cache: Optional[ObservableDeferred[None]] = None
cache: ObservableDeferred[None] | None = None
async def do_something_else(
to_resolve: Deferred[None]
@ -358,7 +358,7 @@ with LoggingContext("request-1"):
**OK**:
```python
cache: Optional[ObservableDeferred[None]] = None
cache: ObservableDeferred[None] | None = None
async def do_something_else(
to_resolve: Deferred[None]

View file

@ -15,7 +15,7 @@ _First introduced in Synapse v1.57.0_
```python
async def on_account_data_updated(
user_id: str,
room_id: Optional[str],
room_id: str | None,
account_data_type: str,
content: "synapse.module_api.JsonDict",
) -> None:
@ -82,7 +82,7 @@ class CustomAccountDataModule:
async def log_new_account_data(
self,
user_id: str,
room_id: Optional[str],
room_id: str | None,
account_data_type: str,
content: JsonDict,
) -> None:

View file

@ -12,7 +12,7 @@ The available account validity callbacks are:
_First introduced in Synapse v1.39.0_
```python
async def is_user_expired(user: str) -> Optional[bool]
async def is_user_expired(user: str) -> bool | None
```
Called when processing any authenticated request (except for logout requests). The module

View file

@ -11,7 +11,7 @@ The available media repository callbacks are:
_First introduced in Synapse v1.132.0_
```python
async def get_media_config_for_user(user_id: str) -> Optional[JsonDict]
async def get_media_config_for_user(user_id: str) -> JsonDict | None
```
**<span style="color:red">
@ -70,7 +70,7 @@ implementations of this callback.
_First introduced in Synapse v1.139.0_
```python
async def get_media_upload_limits_for_user(user_id: str, size: int) -> Optional[List[synapse.module_api.MediaUploadLimit]]
async def get_media_upload_limits_for_user(user_id: str, size: int) -> list[synapse.module_api.MediaUploadLimit] | None
```
**<span style="color:red">

View file

@ -23,12 +23,7 @@ async def check_auth(
user: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
]
]
) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None
```
The login type and field names should be provided by the user in the
@ -67,12 +62,7 @@ async def check_3pid_auth(
medium: str,
address: str,
password: str,
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
]
]
) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None
```
Called when a user attempts to register or log in with a third party identifier,
@ -98,7 +88,7 @@ _First introduced in Synapse v1.46.0_
```python
async def on_logged_out(
user_id: str,
device_id: Optional[str],
device_id: str | None,
access_token: str
) -> None
```
@ -119,7 +109,7 @@ _First introduced in Synapse v1.52.0_
async def get_username_for_registration(
uia_results: Dict[str, Any],
params: Dict[str, Any],
) -> Optional[str]
) -> str | None
```
Called when registering a new user. The module can return a username to set for the user
@ -180,7 +170,7 @@ _First introduced in Synapse v1.54.0_
async def get_displayname_for_registration(
uia_results: Dict[str, Any],
params: Dict[str, Any],
) -> Optional[str]
) -> str | None
```
Called when registering a new user. The module can return a display name to set for the
@ -259,12 +249,7 @@ class MyAuthProvider:
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
]
]:
) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None:
if login_type != "my.login_type":
return None
@ -276,12 +261,7 @@ class MyAuthProvider:
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
]
]:
) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None:
if login_type != "m.login.password":
return None

View file

@ -23,7 +23,7 @@ _First introduced in Synapse v1.42.0_
```python
async def get_users_for_states(
state_updates: Iterable["synapse.api.UserPresenceState"],
) -> Dict[str, Set["synapse.api.UserPresenceState"]]
) -> dict[str, set["synapse.api.UserPresenceState"]]
```
**Requires** `get_interested_users` to also be registered
@ -45,7 +45,7 @@ _First introduced in Synapse v1.42.0_
```python
async def get_interested_users(
user_id: str
) -> Union[Set[str], "synapse.module_api.PRESENCE_ALL_USERS"]
) -> set[str] | "synapse.module_api.PRESENCE_ALL_USERS"
```
**Requires** `get_users_for_states` to also be registered
@ -73,7 +73,7 @@ that `@alice:example.org` receives all presence updates from `@bob:example.com`
`@charlie:somewhere.org`, regardless of whether Alice shares a room with any of them.
```python
from typing import Dict, Iterable, Set, Union
from typing import Iterable
from synapse.module_api import ModuleApi
@ -90,7 +90,7 @@ class CustomPresenceRouter:
async def get_users_for_states(
self,
state_updates: Iterable["synapse.api.UserPresenceState"],
) -> Dict[str, Set["synapse.api.UserPresenceState"]]:
) -> dict[str, set["synapse.api.UserPresenceState"]]:
res = {}
for update in state_updates:
if (
@ -104,7 +104,7 @@ class CustomPresenceRouter:
async def get_interested_users(
self,
user_id: str,
) -> Union[Set[str], "synapse.module_api.PRESENCE_ALL_USERS"]:
) -> set[str] | "synapse.module_api.PRESENCE_ALL_USERS":
if user_id == "@alice:example.com":
return {"@bob:example.com", "@charlie:somewhere.org"}

View file

@ -11,7 +11,7 @@ The available ratelimit callbacks are:
_First introduced in Synapse v1.132.0_
```python
async def get_ratelimit_override_for_user(user: str, limiter_name: str) -> Optional[synapse.module_api.RatelimitOverride]
async def get_ratelimit_override_for_user(user: str, limiter_name: str) -> synapse.module_api.RatelimitOverride | None
```
**<span style="color:red">

View file

@ -331,9 +331,9 @@ search results; otherwise return `False`.
The profile is represented as a dictionary with the following keys:
* `user_id: str`. The Matrix ID for this user.
* `display_name: Optional[str]`. The user's display name, or `None` if this user
* `display_name: str | None`. The user's display name, or `None` if this user
has not set a display name.
* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None`
* `avatar_url: str | None`. The `mxc://` URL to the user's avatar, or `None`
if this user has not set an avatar.
The module is given a copy of the original dictionary, so modifying it from within the
@ -352,10 +352,10 @@ _First introduced in Synapse v1.37.0_
```python
async def check_registration_for_spam(
email_threepid: Optional[dict],
username: Optional[str],
email_threepid: dict | None,
username: str | None,
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str] = None,
auth_provider_id: str | None = None,
) -> "synapse.spam_checker_api.RegistrationBehaviour"
```
@ -438,10 +438,10 @@ _First introduced in Synapse v1.87.0_
```python
async def check_login_for_spam(
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
device_id: str | None,
initial_display_name: str | None,
request_info: Collection[tuple[str | None, str]],
auth_provider_id: str | None = None,
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
```
@ -509,7 +509,7 @@ class ListSpamChecker:
resource=IsUserEvilResource(config),
)
async def check_event_for_spam(self, event: "synapse.events.EventBase") -> Union[Literal["NOT_SPAM"], Codes]:
async def check_event_for_spam(self, event: "synapse.events.EventBase") -> Literal["NOT_SPAM"] | Codes:
if event.sender in self.evil_users:
return Codes.FORBIDDEN
else:

View file

@ -16,7 +16,7 @@ _First introduced in Synapse v1.39.0_
async def check_event_allowed(
event: "synapse.events.EventBase",
state_events: "synapse.types.StateMap",
) -> Tuple[bool, Optional[dict]]
) -> tuple[bool, dict | None]
```
**<span style="color:red">
@ -340,7 +340,7 @@ class EventCensorer:
self,
event: "synapse.events.EventBase",
state_events: "synapse.types.StateMap",
) -> Tuple[bool, Optional[dict]]:
) -> tuple[bool, dict | None]:
event_dict = event.get_dict()
new_event_content = await self.api.http_client.post_json_get_json(
uri=self._endpoint, post_json=event_dict,

View file

@ -76,7 +76,7 @@ possible.
#### `get_interested_users`
```python
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]
async def get_interested_users(self, user_id: str) -> set[str] | str
```
**Required.** An asynchronous method that is passed a single Matrix User ID. This
@ -182,7 +182,7 @@ class ExamplePresenceRouter:
async def get_interested_users(
self,
user_id: str,
) -> Union[Set[str], PresenceRouter.ALL_USERS]:
) -> set[str] | PresenceRouter.ALL_USERS:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.

View file

@ -80,10 +80,15 @@ select = [
"G",
# pyupgrade
"UP006",
"UP007",
"UP045",
]
extend-safe-fixes = [
# pyupgrade
"UP006"
# pyupgrade rules compatible with Python >= 3.9
"UP006",
"UP007",
# pyupgrade rules compatible with Python >= 3.10
"UP045",
]
[tool.ruff.lint.isort]

View file

@ -18,7 +18,7 @@ import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from types import FrameType
from typing import Collection, Optional, Sequence
from typing import Collection, Sequence
# These are expanded inside the dockerfile to be a fully qualified image name.
# e.g. docker.io/library/debian:bookworm
@ -49,7 +49,7 @@ class Builder:
def __init__(
self,
redirect_stdout: bool = False,
docker_build_args: Optional[Sequence[str]] = None,
docker_build_args: Sequence[str] | None = None,
):
self.redirect_stdout = redirect_stdout
self._docker_build_args = tuple(docker_build_args or ())
@ -167,7 +167,7 @@ class Builder:
def run_builds(
builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False
) -> None:
def sig(signum: int, _frame: Optional[FrameType]) -> None:
def sig(signum: int, _frame: FrameType | None) -> None:
print("Caught SIGINT")
builder.kill_containers()

View file

@ -43,7 +43,7 @@ import argparse
import base64
import json
import sys
from typing import Any, Mapping, Optional, Union
from typing import Any, Mapping
from urllib import parse as urlparse
import requests
@ -103,12 +103,12 @@ def sign_json(
def request(
method: Optional[str],
method: str | None,
origin_name: str,
origin_key: signedjson.types.SigningKey,
destination: str,
path: str,
content: Optional[str],
content: str | None,
verify_tls: bool,
) -> requests.Response:
if method is None:
@ -301,9 +301,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
def get_connection_with_tls_context(
self,
request: PreparedRequest,
verify: Optional[Union[bool, str]],
proxies: Optional[Mapping[str, str]] = None,
cert: Optional[Union[tuple[str, str], str]] = None,
verify: bool | str | None,
proxies: Mapping[str, str] | None = None,
cert: tuple[str, str] | str | None = None,
) -> HTTPConnectionPool:
# overrides the get_connection_with_tls_context() method in the base class
parsed = urlparse.urlsplit(request.url)
@ -368,7 +368,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
return server_name, 8448, server_name
@staticmethod
def _get_well_known(server_name: str) -> Optional[str]:
def _get_well_known(server_name: str) -> str | None:
if ":" in server_name:
# explicit port, or ipv6 literal. Either way, no .well-known
return None

View file

@ -4,7 +4,7 @@
import json
import re
import sys
from typing import Any, Optional
from typing import Any
import yaml
@ -259,17 +259,17 @@ def indent(text: str, first_line: bool = True) -> str:
return text
def em(s: Optional[str]) -> str:
def em(s: str | None) -> str:
"""Add emphasis to text."""
return f"*{s}*" if s else ""
def a(s: Optional[str], suffix: str = " ") -> str:
def a(s: str | None, suffix: str = " ") -> str:
"""Appends a space if the given string is not empty."""
return s + suffix if s else ""
def p(s: Optional[str], prefix: str = " ") -> str:
def p(s: str | None, prefix: str = " ") -> str:
"""Prepend a space if the given string is not empty."""
return prefix + s if s else ""

View file

@ -24,7 +24,7 @@ can crop up, e.g the cache descriptors.
"""
import enum
from typing import Callable, Mapping, Optional, Union
from typing import Callable, Mapping
import attr
import mypy.types
@ -123,7 +123,7 @@ class ArgLocation:
"""
prometheus_metric_fullname_to_label_arg_map: Mapping[str, Optional[ArgLocation]] = {
prometheus_metric_fullname_to_label_arg_map: Mapping[str, ArgLocation | None] = {
# `Collector` subclasses:
"prometheus_client.metrics.MetricWrapperBase": ArgLocation("labelnames", 2),
"prometheus_client.metrics.Counter": ArgLocation("labelnames", 2),
@ -211,7 +211,7 @@ class SynapsePlugin(Plugin):
def get_base_class_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
) -> Callable[[ClassDefContext], None] | None:
def _get_base_class_hook(ctx: ClassDefContext) -> None:
# Run any `get_base_class_hook` checks from other plugins first.
#
@ -232,7 +232,7 @@ class SynapsePlugin(Plugin):
def get_function_signature_hook(
self, fullname: str
) -> Optional[Callable[[FunctionSigContext], FunctionLike]]:
) -> Callable[[FunctionSigContext], FunctionLike] | None:
# Strip off the unique identifier for classes that are dynamically created inside
# functions. ex. `synapse.metrics.jemalloc.JemallocCollector@185` (this is the line
# number)
@ -262,7 +262,7 @@ class SynapsePlugin(Plugin):
def get_method_signature_hook(
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
) -> Callable[[MethodSigContext], CallableType] | None:
if fullname.startswith(
(
"synapse.util.caches.descriptors.CachedFunction.__call__",
@ -721,7 +721,7 @@ def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType:
def check_is_cacheable(
signature: CallableType,
ctx: Union[MethodSigContext, FunctionSigContext],
ctx: MethodSigContext | FunctionSigContext,
) -> None:
"""
Check if a callable returns a type which can be cached.
@ -795,7 +795,7 @@ AT_CACHED_MUTABLE_RETURN = ErrorCode(
def is_cacheable(
rt: mypy.types.Type, signature: CallableType, verbose: bool
) -> tuple[bool, Optional[str]]:
) -> tuple[bool, str | None]:
"""
Check if a particular type is cachable.

View file

@ -32,7 +32,7 @@ import time
import urllib.request
from os import path
from tempfile import TemporaryDirectory
from typing import Any, Match, Optional, Union
from typing import Any, Match
import attr
import click
@ -327,11 +327,11 @@ def _prepare() -> None:
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"])
def tag(gh_token: Optional[str]) -> None:
def tag(gh_token: str | None) -> None:
_tag(gh_token)
def _tag(gh_token: Optional[str]) -> None:
def _tag(gh_token: str | None) -> None:
"""Tags the release and generates a draft GitHub release"""
# Test that the GH Token is valid before continuing.
@ -471,11 +471,11 @@ def _publish(gh_token: str) -> None:
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False)
def upload(gh_token: Optional[str]) -> None:
def upload(gh_token: str | None) -> None:
_upload(gh_token)
def _upload(gh_token: Optional[str]) -> None:
def _upload(gh_token: str | None) -> None:
"""Upload release to pypi."""
# Test that the GH Token is valid before continuing.
@ -576,11 +576,11 @@ def _merge_into(repo: Repo, source: str, target: str) -> None:
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False)
def wait_for_actions(gh_token: Optional[str]) -> None:
def wait_for_actions(gh_token: str | None) -> None:
_wait_for_actions(gh_token)
def _wait_for_actions(gh_token: Optional[str]) -> None:
def _wait_for_actions(gh_token: str | None) -> None:
# Test that the GH Token is valid before continuing.
check_valid_gh_token(gh_token)
@ -658,7 +658,7 @@ def _notify(message: str) -> None:
envvar=["GH_TOKEN", "GITHUB_TOKEN"],
required=False,
)
def merge_back(_gh_token: Optional[str]) -> None:
def merge_back(_gh_token: str | None) -> None:
_merge_back()
@ -715,7 +715,7 @@ def _merge_back() -> None:
envvar=["GH_TOKEN", "GITHUB_TOKEN"],
required=False,
)
def announce(_gh_token: Optional[str]) -> None:
def announce(_gh_token: str | None) -> None:
_announce()
@ -851,7 +851,7 @@ def get_repo_and_check_clean_checkout(
return repo
def check_valid_gh_token(gh_token: Optional[str]) -> None:
def check_valid_gh_token(gh_token: str | None) -> None:
"""Check that a github token is valid, if supplied"""
if not gh_token:
@ -867,7 +867,7 @@ def check_valid_gh_token(gh_token: Optional[str]) -> None:
raise click.ClickException(f"Github credentials are bad: {e}")
def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
def find_ref(repo: git.Repo, ref_name: str) -> git.HEAD | None:
"""Find the branch/ref, looking first locally then in the remote."""
if ref_name in repo.references:
return repo.references[ref_name]
@ -904,7 +904,7 @@ def get_changes_for_version(wanted_version: version.Version) -> str:
# These are 0-based.
start_line: int
end_line: Optional[int] = None # Is none if its the last entry
end_line: int | None = None # Is none if its the last entry
headings: list[VersionSection] = []
for i, token in enumerate(tokens):
@ -991,7 +991,7 @@ def build_dependabot_changelog(repo: Repo, current_version: version.Version) ->
messages = []
for commit in reversed(commits):
if commit.author.name == "dependabot[bot]":
message: Union[str, bytes] = commit.message
message: str | bytes = commit.message
if isinstance(message, bytes):
message = message.decode("utf-8")
messages.append(message.split("\n", maxsplit=1)[0])

View file

@ -38,7 +38,7 @@ import io
import json
import sys
from collections import defaultdict
from typing import Any, Iterator, Optional
from typing import Any, Iterator
import git
from packaging import version
@ -57,7 +57,7 @@ SCHEMA_VERSION_FILES = (
OLDEST_SHOWN_VERSION = version.parse("v1.0")
def get_schema_versions(tag: git.Tag) -> tuple[Optional[int], Optional[int]]:
def get_schema_versions(tag: git.Tag) -> tuple[int | None, int | None]:
"""Get the schema and schema compat versions for a tag."""
schema_version = None
schema_compat_version = None

View file

@ -13,10 +13,8 @@ from typing import (
Iterator,
KeysView,
Mapping,
Optional,
Sequence,
TypeVar,
Union,
ValuesView,
overload,
)
@ -51,7 +49,7 @@ class SortedDict(dict[_KT, _VT]):
self, __key: _Key[_KT], __iterable: Iterable[tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
@property
def key(self) -> Optional[_Key[_KT]]: ...
def key(self) -> _Key[_KT] | None: ...
@property
def iloc(self) -> SortedKeysView[_KT]: ...
def clear(self) -> None: ...
@ -79,10 +77,10 @@ class SortedDict(dict[_KT, _VT]):
@overload
def pop(self, key: _KT) -> _VT: ...
@overload
def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ...
def pop(self, key: _KT, default: _T = ...) -> _VT | _T: ...
def popitem(self, index: int = ...) -> tuple[_KT, _VT]: ...
def peekitem(self, index: int = ...) -> tuple[_KT, _VT]: ...
def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ...
def setdefault(self, key: _KT, default: _VT | None = ...) -> _VT: ...
# Mypy now reports the first overload as an error, because typeshed widened the type
# of `__map` to its internal `_typeshed.SupportsKeysAndGetItem` type in
# https://github.com/python/typeshed/pull/6653
@ -106,8 +104,8 @@ class SortedDict(dict[_KT, _VT]):
def _check(self) -> None: ...
def islice(
self,
start: Optional[int] = ...,
stop: Optional[int] = ...,
start: int | None = ...,
stop: int | None = ...,
reverse: bool = ...,
) -> Iterator[_KT]: ...
def bisect_left(self, value: _KT) -> int: ...
@ -118,7 +116,7 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]):
def __getitem__(self, index: int) -> _KT_co: ...
@overload
def __getitem__(self, index: slice) -> list[_KT_co]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
def __delitem__(self, index: int | slice) -> None: ...
class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[tuple[_KT_co, _VT_co]]):
def __iter__(self) -> Iterator[tuple[_KT_co, _VT_co]]: ...
@ -126,11 +124,11 @@ class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[tuple[_KT_co, _VT_co]]
def __getitem__(self, index: int) -> tuple[_KT_co, _VT_co]: ...
@overload
def __getitem__(self, index: slice) -> list[tuple[_KT_co, _VT_co]]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
def __delitem__(self, index: int | slice) -> None: ...
class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]):
@overload
def __getitem__(self, index: int) -> _VT_co: ...
@overload
def __getitem__(self, index: slice) -> list[_VT_co]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
def __delitem__(self, index: int | slice) -> None: ...

View file

@ -10,10 +10,8 @@ from typing import (
Iterable,
Iterator,
MutableSequence,
Optional,
Sequence,
TypeVar,
Union,
overload,
)
@ -29,8 +27,8 @@ class SortedList(MutableSequence[_T]):
DEFAULT_LOAD_FACTOR: int = ...
def __init__(
self,
iterable: Optional[Iterable[_T]] = ...,
key: Optional[_Key[_T]] = ...,
iterable: Iterable[_T] | None = ...,
key: _Key[_T] | None = ...,
): ...
# NB: currently mypy does not honour return type, see mypy #3307
@overload
@ -42,7 +40,7 @@ class SortedList(MutableSequence[_T]):
@overload
def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ...
@property
def key(self) -> Optional[Callable[[_T], Any]]: ...
def key(self) -> Callable[[_T], Any] | None: ...
def _reset(self, load: int) -> None: ...
def clear(self) -> None: ...
def _clear(self) -> None: ...
@ -57,7 +55,7 @@ class SortedList(MutableSequence[_T]):
def _pos(self, idx: int) -> int: ...
def _build_index(self) -> None: ...
def __contains__(self, value: Any) -> bool: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
def __delitem__(self, index: int | slice) -> None: ...
@overload
def __getitem__(self, index: int) -> _T: ...
@overload
@ -76,8 +74,8 @@ class SortedList(MutableSequence[_T]):
def reverse(self) -> None: ...
def islice(
self,
start: Optional[int] = ...,
stop: Optional[int] = ...,
start: int | None = ...,
stop: int | None = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def _islice(
@ -90,8 +88,8 @@ class SortedList(MutableSequence[_T]):
) -> Iterator[_T]: ...
def irange(
self,
minimum: Optional[int] = ...,
maximum: Optional[int] = ...,
minimum: int | None = ...,
maximum: int | None = ...,
inclusive: tuple[bool, bool] = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
@ -107,7 +105,7 @@ class SortedList(MutableSequence[_T]):
def insert(self, index: int, value: _T) -> None: ...
def pop(self, index: int = ...) -> _T: ...
def index(
self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
self, value: _T, start: int | None = ..., stop: int | None = ...
) -> int: ...
def __add__(self: _SL, other: Iterable[_T]) -> _SL: ...
def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ...
@ -126,10 +124,10 @@ class SortedList(MutableSequence[_T]):
class SortedKeyList(SortedList[_T]):
def __init__(
self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
self, iterable: Iterable[_T] | None = ..., key: _Key[_T] = ...
) -> None: ...
def __new__(
cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
cls, iterable: Iterable[_T] | None = ..., key: _Key[_T] = ...
) -> SortedKeyList[_T]: ...
@property
def key(self) -> Callable[[_T], Any]: ...
@ -146,15 +144,15 @@ class SortedKeyList(SortedList[_T]):
def _delete(self, pos: int, idx: int) -> None: ...
def irange(
self,
minimum: Optional[int] = ...,
maximum: Optional[int] = ...,
minimum: int | None = ...,
maximum: int | None = ...,
inclusive: tuple[bool, bool] = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def irange_key(
self,
min_key: Optional[Any] = ...,
max_key: Optional[Any] = ...,
min_key: Any | None = ...,
max_key: Any | None = ...,
inclusive: tuple[bool, bool] = ...,
reserve: bool = ...,
) -> Iterator[_T]: ...
@ -170,7 +168,7 @@ class SortedKeyList(SortedList[_T]):
def copy(self: _SKL) -> _SKL: ...
def __copy__(self: _SKL) -> _SKL: ...
def index(
self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
self, value: _T, start: int | None = ..., stop: int | None = ...
) -> int: ...
def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...

View file

@ -11,10 +11,8 @@ from typing import (
Iterable,
Iterator,
MutableSet,
Optional,
Sequence,
TypeVar,
Union,
overload,
)
@ -28,21 +26,19 @@ _Key = Callable[[_T], Any]
class SortedSet(MutableSet[_T], Sequence[_T]):
def __init__(
self,
iterable: Optional[Iterable[_T]] = ...,
key: Optional[_Key[_T]] = ...,
iterable: Iterable[_T] | None = ...,
key: _Key[_T] | None = ...,
) -> None: ...
@classmethod
def _fromset(
cls, values: set[_T], key: Optional[_Key[_T]] = ...
) -> SortedSet[_T]: ...
def _fromset(cls, values: set[_T], key: _Key[_T] | None = ...) -> SortedSet[_T]: ...
@property
def key(self) -> Optional[_Key[_T]]: ...
def key(self) -> _Key[_T] | None: ...
def __contains__(self, value: Any) -> bool: ...
@overload
def __getitem__(self, index: int) -> _T: ...
@overload
def __getitem__(self, index: slice) -> list[_T]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
def __delitem__(self, index: int | slice) -> None: ...
def __eq__(self, other: Any) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
def __lt__(self, other: Iterable[_T]) -> bool: ...
@ -62,32 +58,28 @@ class SortedSet(MutableSet[_T], Sequence[_T]):
def _discard(self, value: _T) -> None: ...
def pop(self, index: int = ...) -> _T: ...
def remove(self, value: _T) -> None: ...
def difference(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __sub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def difference_update(
self, *iterables: Iterable[_S]
) -> SortedSet[Union[_T, _S]]: ...
def __isub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def intersection(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __and__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __rand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def intersection_update(
self, *iterables: Iterable[_S]
) -> SortedSet[Union[_T, _S]]: ...
def __iand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def symmetric_difference(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __xor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __rxor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def difference(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __sub__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def difference_update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __isub__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def intersection(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __and__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __rand__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def intersection_update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __iand__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def symmetric_difference(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __xor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __rxor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ...
def symmetric_difference_update(
self, other: Iterable[_S]
) -> SortedSet[Union[_T, _S]]: ...
def __ixor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def union(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __or__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __ror__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __ior__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def _update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
) -> SortedSet[_T | _S]: ...
def __ixor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ...
def union(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __or__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __ror__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __ior__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def _update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
def __reduce__(
self,
) -> tuple[type[SortedSet[_T]], set[_T], Callable[[_T], Any]]: ...
@ -97,18 +89,18 @@ class SortedSet(MutableSet[_T], Sequence[_T]):
def bisect_right(self, value: _T) -> int: ...
def islice(
self,
start: Optional[int] = ...,
stop: Optional[int] = ...,
start: int | None = ...,
stop: int | None = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def irange(
self,
minimum: Optional[_T] = ...,
maximum: Optional[_T] = ...,
minimum: _T | None = ...,
maximum: _T | None = ...,
inclusive: tuple[bool, bool] = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def index(
self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
self, value: _T, start: int | None = ..., stop: int | None = ...
) -> int: ...
def _reset(self, load: int) -> None: ...

View file

@ -15,7 +15,7 @@
"""Contains *incomplete* type hints for txredisapi."""
from typing import Any, Optional, Union
from typing import Any
from twisted.internet import protocol
from twisted.internet.defer import Deferred
@ -29,8 +29,8 @@ class RedisProtocol(protocol.Protocol):
self,
key: str,
value: Any,
expire: Optional[int] = None,
pexpire: Optional[int] = None,
expire: int | None = None,
pexpire: int | None = None,
only_if_not_exists: bool = False,
only_if_exists: bool = False,
) -> "Deferred[None]": ...
@ -38,8 +38,8 @@ class RedisProtocol(protocol.Protocol):
class SubscriberProtocol(RedisProtocol):
def __init__(self, *args: object, **kwargs: object): ...
password: Optional[str]
def subscribe(self, channels: Union[str, list[str]]) -> "Deferred[None]": ...
password: str | None
def subscribe(self, channels: str | list[str]) -> "Deferred[None]": ...
def connectionMade(self) -> None: ...
# type-ignore: twisted.internet.protocol.Protocol provides a default argument for
# `reason`. txredisapi's LineReceiver Protocol doesn't. But that's fine: it's what's
@ -49,12 +49,12 @@ class SubscriberProtocol(RedisProtocol):
def lazyConnection(
host: str = ...,
port: int = ...,
dbid: Optional[int] = ...,
dbid: int | None = ...,
reconnect: bool = ...,
charset: str = ...,
password: Optional[str] = ...,
connectTimeout: Optional[int] = ...,
replyTimeout: Optional[int] = ...,
password: str | None = ...,
connectTimeout: int | None = ...,
replyTimeout: int | None = ...,
convertNumbers: bool = ...,
) -> RedisProtocol: ...
@ -70,18 +70,18 @@ class RedisFactory(protocol.ReconnectingClientFactory):
continueTrying: bool
handler: ConnectionHandler
pool: list[RedisProtocol]
replyTimeout: Optional[int]
replyTimeout: int | None
def __init__(
self,
uuid: str,
dbid: Optional[int],
dbid: int | None,
poolsize: int,
isLazy: bool = False,
handler: type = ConnectionHandler,
charset: str = "utf-8",
password: Optional[str] = None,
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
password: str | None = None,
replyTimeout: int | None = None,
convertNumbers: int | None = True,
): ...
def buildProtocol(self, addr: IAddress) -> RedisProtocol: ...

View file

@ -22,13 +22,13 @@
import argparse
import sys
import time
from typing import NoReturn, Optional
from typing import NoReturn
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
from signedjson.types import VerifyKey
def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
def exit(status: int = 0, message: str | None = None) -> NoReturn:
if message:
print(message, file=sys.stderr)
sys.exit(status)

View file

@ -25,7 +25,7 @@ import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Iterable, Optional, Pattern
from typing import Iterable, Pattern
import yaml
@ -46,7 +46,7 @@ logger = logging.getLogger("generate_workers_map")
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore
def __init__(self, config: HomeServerConfig, worker_app: Optional[str]) -> None:
def __init__(self, config: HomeServerConfig, worker_app: str | None) -> None:
super().__init__(config.server.server_name, config=config)
self.config.worker.worker_app = worker_app
@ -65,7 +65,7 @@ class EndpointDescription:
# The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
# class.
category: Optional[str]
category: str | None
# TODO:
# - does it need to be routed based on a stream writer config?
@ -141,7 +141,7 @@ def get_registered_paths_for_hs(
def get_registered_paths_for_default(
worker_app: Optional[str], base_config: HomeServerConfig
worker_app: str | None, base_config: HomeServerConfig
) -> dict[tuple[str, str], EndpointDescription]:
"""
Given the name of a worker application and a base homeserver configuration,
@ -271,7 +271,7 @@ def main() -> None:
# TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT
categories_to_methods_and_paths: dict[
Optional[str], dict[tuple[str, str], EndpointDescription]
str | None, dict[tuple[str, str], EndpointDescription]
] = defaultdict(dict)
for (method, path), desc in elided_worker_paths.items():
@ -282,7 +282,7 @@ def main() -> None:
def print_category(
category_name: Optional[str],
category_name: str | None,
elided_worker_paths: dict[tuple[str, str], EndpointDescription],
) -> None:
"""

View file

@ -26,7 +26,7 @@ import hashlib
import hmac
import logging
import sys
from typing import Any, Callable, Optional
from typing import Any, Callable
import requests
import yaml
@ -54,7 +54,7 @@ def request_registration(
server_location: str,
shared_secret: str,
admin: bool = False,
user_type: Optional[str] = None,
user_type: str | None = None,
_print: Callable[[str], None] = print,
exit: Callable[[int], None] = sys.exit,
exists_ok: bool = False,
@ -123,13 +123,13 @@ def register_new_user(
password: str,
server_location: str,
shared_secret: str,
admin: Optional[bool],
user_type: Optional[str],
admin: bool | None,
user_type: str | None,
exists_ok: bool = False,
) -> None:
if not user:
try:
default_user: Optional[str] = getpass.getuser()
default_user: str | None = getpass.getuser()
except Exception:
default_user = None
@ -262,7 +262,7 @@ def main() -> None:
args = parser.parse_args()
config: Optional[dict[str, Any]] = None
config: dict[str, Any] | None = None
if "config" in args and args.config:
config = yaml.safe_load(args.config)
@ -350,7 +350,7 @@ def _read_file(file_path: Any, config_path: str) -> str:
sys.exit(1)
def _find_client_listener(config: dict[str, Any]) -> Optional[str]:
def _find_client_listener(config: dict[str, Any]) -> str | None:
# try to find a listener in the config. Returns a host:port pair
for listener in config.get("listeners", []):
if listener.get("type") != "http" or listener.get("tls", False):

View file

@ -233,14 +233,14 @@ IGNORED_BACKGROUND_UPDATES = {
# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
end_error: Optional[str] = None
end_error: str | None = None
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace.
end_error_exec_info: Optional[
tuple[type[BaseException], BaseException, TracebackType]
] = None
end_error_exec_info: tuple[type[BaseException], BaseException, TracebackType] | None = (
None
)
R = TypeVar("R")
@ -485,7 +485,7 @@ class Porter:
def r(
txn: LoggingTransaction,
) -> tuple[Optional[list[str]], list[tuple], list[tuple]]:
) -> tuple[list[str] | None, list[tuple], list[tuple]]:
forward_rows = []
backward_rows = []
if do_forward[0]:
@ -502,7 +502,7 @@ class Porter:
if forward_rows or backward_rows:
assert txn.description is not None
headers: Optional[list[str]] = [
headers: list[str] | None = [
column[0] for column in txn.description
]
else:
@ -1152,9 +1152,7 @@ class Porter:
return done, remaining + done
async def _setup_state_group_id_seq(self) -> None:
curr_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
curr_id: int | None = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
)
@ -1271,10 +1269,10 @@ class Porter:
await self.postgres_store.db_pool.runInteraction("_setup_%s" % (seq_name,), r)
async def _pg_get_serial_sequence(self, table: str, column: str) -> Optional[str]:
async def _pg_get_serial_sequence(self, table: str, column: str) -> str | None:
"""Returns the name of the postgres sequence associated with a column, or NULL."""
def r(txn: LoggingTransaction) -> Optional[str]:
def r(txn: LoggingTransaction) -> str | None:
txn.execute("SELECT pg_get_serial_sequence('%s', '%s')" % (table, column))
result = txn.fetchone()
if not result:
@ -1286,9 +1284,9 @@ class Porter:
)
async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
curr_chain_id: (
int | None
) = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains",
keyvalues={},
retcol="MAX(chain_id)",

View file

@ -30,7 +30,7 @@ import signal
import subprocess
import sys
import time
from typing import Iterable, NoReturn, Optional, TextIO
from typing import Iterable, NoReturn, TextIO
import yaml
@ -135,7 +135,7 @@ def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool)
return False
def stop(pidfile: str, app: str) -> Optional[int]:
def stop(pidfile: str, app: str) -> int | None:
"""Attempts to kill a synapse worker from the pidfile.
Args:
pidfile: path to file containing worker's pid

View file

@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING, Optional, Protocol
from typing import TYPE_CHECKING, Protocol
from prometheus_client import Histogram
@ -51,7 +51,7 @@ class Auth(Protocol):
room_id: str,
requester: Requester,
allow_departed_users: bool = False,
) -> tuple[str, Optional[str]]:
) -> tuple[str, str | None]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@ -190,7 +190,7 @@ class Auth(Protocol):
async def check_user_in_room_or_world_readable(
self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> tuple[str, Optional[str]]:
) -> tuple[str, str | None]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.

View file

@ -19,7 +19,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from netaddr import IPAddress
@ -64,7 +64,7 @@ class BaseAuth:
room_id: str,
requester: Requester,
allow_departed_users: bool = False,
) -> tuple[str, Optional[str]]:
) -> tuple[str, str | None]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@ -114,7 +114,7 @@ class BaseAuth:
@trace
async def check_user_in_room_or_world_readable(
self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> tuple[str, Optional[str]]:
) -> tuple[str, str | None]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
@ -294,7 +294,7 @@ class BaseAuth:
@cancellable
async def get_appservice_user(
self, request: Request, access_token: str
) -> Optional[Requester]:
) -> Requester | None:
"""
Given a request, reads the request parameters to determine:
- whether it's an application service that's making this request

View file

@ -13,7 +13,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from urllib.parse import urlencode
from pydantic import (
@ -74,11 +74,11 @@ class ServerMetadata(BaseModel):
class IntrospectionResponse(BaseModel):
retrieved_at_ms: StrictInt
active: StrictBool
scope: Optional[StrictStr] = None
username: Optional[StrictStr] = None
sub: Optional[StrictStr] = None
device_id: Optional[StrictStr] = None
expires_in: Optional[StrictInt] = None
scope: StrictStr | None = None
username: StrictStr | None = None
sub: StrictStr | None = None
device_id: StrictStr | None = None
expires_in: StrictInt | None = None
model_config = ConfigDict(extra="allow")
def get_scope_set(self) -> set[str]:

View file

@ -20,7 +20,7 @@
#
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlencode
from authlib.oauth2 import ClientAuth
@ -102,25 +102,25 @@ class IntrospectionResult:
return []
return scope_to_list(value)
def get_sub(self) -> Optional[str]:
def get_sub(self) -> str | None:
value = self._inner.get("sub")
if not isinstance(value, str):
return None
return value
def get_username(self) -> Optional[str]:
def get_username(self) -> str | None:
value = self._inner.get("username")
if not isinstance(value, str):
return None
return value
def get_name(self) -> Optional[str]:
def get_name(self) -> str | None:
value = self._inner.get("name")
if not isinstance(value, str):
return None
return value
def get_device_id(self) -> Optional[str]:
def get_device_id(self) -> str | None:
value = self._inner.get("device_id")
if value is not None and not isinstance(value, str):
raise AuthError(
@ -174,7 +174,7 @@ class MSC3861DelegatedAuth(BaseAuth):
self._clock = hs.get_clock()
self._http_client = hs.get_proxied_http_client()
self._hostname = hs.hostname
self._admin_token: Callable[[], Optional[str]] = self._config.admin_token
self._admin_token: Callable[[], str | None] = self._config.admin_token
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
self._rust_http_client = HttpClient(
@ -247,7 +247,7 @@ class MSC3861DelegatedAuth(BaseAuth):
metadata = await self._issuer_metadata.get()
return metadata.issuer or self._config.issuer
async def account_management_url(self) -> Optional[str]:
async def account_management_url(self) -> str | None:
"""
Get the configured account management URL

View file

@ -20,7 +20,7 @@
#
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
@ -51,10 +51,10 @@ class AuthBlocking:
async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
user_id: str | None = None,
threepid: dict | None = None,
user_type: str | None = None,
requester: Requester | None = None,
) -> None:
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag

View file

@ -26,7 +26,7 @@ import math
import typing
from enum import Enum
from http import HTTPStatus
from typing import Any, Optional, Union
from typing import Any, Optional
from twisted.web import http
@ -164,9 +164,9 @@ class CodeMessageException(RuntimeError):
def __init__(
self,
code: Union[int, HTTPStatus],
code: int | HTTPStatus,
msg: str,
headers: Optional[dict[str, str]] = None,
headers: dict[str, str] | None = None,
):
super().__init__("%d: %s" % (code, msg))
@ -223,8 +223,8 @@ class SynapseError(CodeMessageException):
code: int,
msg: str,
errcode: str = Codes.UNKNOWN,
additional_fields: Optional[dict] = None,
headers: Optional[dict[str, str]] = None,
additional_fields: dict | None = None,
headers: dict[str, str] | None = None,
):
"""Constructs a synapse error.
@ -244,7 +244,7 @@ class SynapseError(CodeMessageException):
return cs_error(self.msg, self.errcode, **self._additional_fields)
@property
def debug_context(self) -> Optional[str]:
def debug_context(self) -> str | None:
"""Override this to add debugging context that shouldn't be sent to clients."""
return None
@ -276,7 +276,7 @@ class ProxiedRequestError(SynapseError):
code: int,
msg: str,
errcode: str = Codes.UNKNOWN,
additional_fields: Optional[dict] = None,
additional_fields: dict | None = None,
):
super().__init__(code, msg, errcode, additional_fields)
@ -340,7 +340,7 @@ class FederationDeniedError(SynapseError):
destination: The destination which has been denied
"""
def __init__(self, destination: Optional[str]):
def __init__(self, destination: str | None):
"""Raised by federation client or server to indicate that we are
are deliberately not attempting to contact a given server because it is
not on our federation whitelist.
@ -399,7 +399,7 @@ class AuthError(SynapseError):
code: int,
msg: str,
errcode: str = Codes.FORBIDDEN,
additional_fields: Optional[dict] = None,
additional_fields: dict | None = None,
):
super().__init__(code, msg, errcode, additional_fields)
@ -432,7 +432,7 @@ class UnstableSpecAuthError(AuthError):
msg: str,
errcode: str,
previous_errcode: str = Codes.FORBIDDEN,
additional_fields: Optional[dict] = None,
additional_fields: dict | None = None,
):
self.previous_errcode = previous_errcode
super().__init__(code, msg, errcode, additional_fields)
@ -497,8 +497,8 @@ class ResourceLimitError(SynapseError):
code: int,
msg: str,
errcode: str = Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact: Optional[str] = None,
limit_type: Optional[str] = None,
admin_contact: str | None = None,
limit_type: str | None = None,
):
self.admin_contact = admin_contact
self.limit_type = limit_type
@ -542,7 +542,7 @@ class InvalidCaptchaError(SynapseError):
self,
code: int = 400,
msg: str = "Invalid captcha.",
error_url: Optional[str] = None,
error_url: str | None = None,
errcode: str = Codes.CAPTCHA_INVALID,
):
super().__init__(code, msg, errcode)
@ -563,9 +563,9 @@ class LimitExceededError(SynapseError):
self,
limiter_name: str,
code: int = 429,
retry_after_ms: Optional[int] = None,
retry_after_ms: int | None = None,
errcode: str = Codes.LIMIT_EXCEEDED,
pause: Optional[float] = None,
pause: float | None = None,
):
# Use HTTP header Retry-After to enable library-assisted retry handling.
headers = (
@ -582,7 +582,7 @@ class LimitExceededError(SynapseError):
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@property
def debug_context(self) -> Optional[str]:
def debug_context(self) -> str | None:
return self.limiter_name
@ -675,7 +675,7 @@ class RequestSendFailed(RuntimeError):
class UnredactedContentDeletedError(SynapseError):
def __init__(self, content_keep_ms: Optional[int] = None):
def __init__(self, content_keep_ms: int | None = None):
super().__init__(
404,
"The content for that event has already been erased from the database",
@ -751,7 +751,7 @@ class FederationError(RuntimeError):
code: int,
reason: str,
affected: str,
source: Optional[str] = None,
source: str | None = None,
):
if level not in ["FATAL", "ERROR", "WARN"]:
raise ValueError("Level is not valid: %s" % (level,))
@ -786,7 +786,7 @@ class FederationPullAttemptBackoffError(RuntimeError):
"""
def __init__(
self, event_ids: "StrCollection", message: Optional[str], retry_after_ms: int
self, event_ids: "StrCollection", message: str | None, retry_after_ms: int
):
event_ids = list(event_ids)

View file

@ -28,9 +28,7 @@ from typing import (
Collection,
Iterable,
Mapping,
Optional,
TypeVar,
Union,
)
import jsonschema
@ -155,7 +153,7 @@ class Filtering:
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
async def get_user_filter(
self, user_id: UserID, filter_id: Union[int, str]
self, user_id: UserID, filter_id: int | str
) -> "FilterCollection":
result = await self.store.get_user_filter(user_id, filter_id)
return FilterCollection(self._hs, result)
@ -531,7 +529,7 @@ class Filter:
return newFilter
def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
def _matches_wildcard(actual_value: str | None, filter_value: str) -> bool:
if filter_value.endswith("*") and isinstance(actual_value, str):
type_prefix = filter_value[:-1]
return actual_value.startswith(type_prefix)

View file

@ -19,7 +19,7 @@
#
#
from typing import Any, Optional
from typing import Any
import attr
@ -41,15 +41,13 @@ class UserDevicePresenceState:
"""
user_id: str
device_id: Optional[str]
device_id: str | None
state: str
last_active_ts: int
last_sync_ts: int
@classmethod
def default(
cls, user_id: str, device_id: Optional[str]
) -> "UserDevicePresenceState":
def default(cls, user_id: str, device_id: str | None) -> "UserDevicePresenceState":
"""Returns a default presence state."""
return cls(
user_id=user_id,
@ -81,7 +79,7 @@ class UserPresenceState:
last_active_ts: int
last_federation_update_ts: int
last_user_sync_ts: int
status_msg: Optional[str]
status_msg: str | None
currently_active: bool
def as_dict(self) -> JsonDict:

View file

@ -102,9 +102,7 @@ class Ratelimiter:
self.clock.looping_call(self._prune_message_counts, 15 * 1000)
def _get_key(
self, requester: Optional[Requester], key: Optional[Hashable]
) -> Hashable:
def _get_key(self, requester: Requester | None, key: Hashable | None) -> Hashable:
"""Use the requester's MXID as a fallback key if no key is provided."""
if key is None:
if not requester:
@ -121,13 +119,13 @@ class Ratelimiter:
async def can_do_action(
self,
requester: Optional[Requester],
key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
requester: Requester | None,
key: Hashable | None = None,
rate_hz: float | None = None,
burst_count: int | None = None,
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
_time_now_s: float | None = None,
) -> tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
@ -247,10 +245,10 @@ class Ratelimiter:
def record_action(
self,
requester: Optional[Requester],
key: Optional[Hashable] = None,
requester: Requester | None,
key: Hashable | None = None,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
_time_now_s: float | None = None,
) -> None:
"""Record that an action(s) took place, even if they violate the rate limit.
@ -332,14 +330,14 @@ class Ratelimiter:
async def ratelimit(
self,
requester: Optional[Requester],
key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
requester: Requester | None,
key: Hashable | None = None,
rate_hz: float | None = None,
burst_count: int | None = None,
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
pause: Optional[float] = 0.5,
_time_now_s: float | None = None,
pause: float | None = 0.5,
) -> None:
"""Checks if an action can be performed. If not, raises a LimitExceededError
@ -396,7 +394,7 @@ class RequestRatelimiter:
store: DataStore,
clock: Clock,
rc_message: RatelimitSettings,
rc_admin_redaction: Optional[RatelimitSettings],
rc_admin_redaction: RatelimitSettings | None,
):
self.store = store
self.clock = clock
@ -412,7 +410,7 @@ class RequestRatelimiter:
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if rc_admin_redaction:
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
self.admin_redaction_ratelimiter: Ratelimiter | None = Ratelimiter(
store=self.store,
clock=self.clock,
cfg=rc_admin_redaction,

View file

@ -18,7 +18,7 @@
#
#
from typing import Callable, Optional
from typing import Callable
import attr
@ -503,7 +503,7 @@ class RoomVersionCapability:
"""An object which describes the unique attributes of a room version."""
identifier: str # the identifier for this capability
preferred_version: Optional[RoomVersion]
preferred_version: RoomVersion | None
support_check_lambda: Callable[[RoomVersion], bool]

View file

@ -24,7 +24,6 @@
import hmac
import urllib.parse
from hashlib import sha256
from typing import Optional
from urllib.parse import urlencode, urljoin
from synapse.config import ConfigError
@ -75,7 +74,7 @@ class LoginSSORedirectURIBuilder:
self._public_baseurl = hs_config.server.public_baseurl
def build_login_sso_redirect_uri(
self, *, idp_id: Optional[str], client_redirect_url: str
self, *, idp_id: str | None, client_redirect_url: str
) -> str:
"""Build a `/login/sso/redirect` URI for the given identity provider.

View file

@ -36,8 +36,6 @@ from typing import (
Awaitable,
Callable,
NoReturn,
Optional,
Union,
cast,
)
from wsgiref.simple_server import WSGIServer
@ -180,8 +178,8 @@ def start_worker_reactor(
def start_reactor(
appname: str,
soft_file_limit: int,
gc_thresholds: Optional[tuple[int, int, int]],
pid_file: Optional[str],
gc_thresholds: tuple[int, int, int] | None,
pid_file: str | None,
daemonize: bool,
print_pidfile: bool,
logger: logging.Logger,
@ -421,7 +419,7 @@ def listen_http(
root_resource: Resource,
version_string: str,
max_request_body_size: int,
context_factory: Optional[IOpenSSLContextFactory],
context_factory: IOpenSSLContextFactory | None,
reactor: ISynapseReactor = reactor,
) -> list[Port]:
"""
@ -564,9 +562,7 @@ def setup_sighup_handling() -> None:
if _already_setup_sighup_handling:
return
previous_sighup_handler: Union[
Callable[[int, Optional[FrameType]], Any], int, None
] = None
previous_sighup_handler: Callable[[int, FrameType | None], Any] | int | None = None
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):

View file

@ -24,7 +24,7 @@ import logging
import os
import sys
import tempfile
from typing import Mapping, Optional, Sequence
from typing import Mapping, Sequence
from twisted.internet import defer, task
@ -136,7 +136,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
to a temporary directory.
"""
def __init__(self, user_id: str, directory: Optional[str] = None):
def __init__(self, user_id: str, directory: str | None = None):
self.user_id = user_id
if directory:
@ -291,7 +291,7 @@ def load_config(argv_options: list[str]) -> tuple[HomeServerConfig, argparse.Nam
def create_homeserver(
config: HomeServerConfig,
reactor: Optional[ISynapseReactor] = None,
reactor: ISynapseReactor | None = None,
) -> AdminCmdServer:
"""
Create a homeserver instance for the Synapse admin command process.

View file

@ -26,7 +26,7 @@ import os
import signal
import sys
from types import FrameType
from typing import Any, Callable, Optional
from typing import Any, Callable
from twisted.internet.main import installReactor
@ -172,7 +172,7 @@ def main() -> None:
# Install signal handlers to propagate signals to all our children, so that they
# shut down cleanly. This also inhibits our own exit, but that's good: we want to
# wait until the children have exited.
def handle_signal(signum: int, frame: Optional[FrameType]) -> None:
def handle_signal(signum: int, frame: FrameType | None) -> None:
print(
f"complement_fork_starter: Caught signal {signum}. Stopping children.",
file=sys.stderr,

View file

@ -21,7 +21,6 @@
#
import logging
import sys
from typing import Optional
from twisted.web.resource import Resource
@ -336,7 +335,7 @@ def load_config(argv_options: list[str]) -> HomeServerConfig:
def create_homeserver(
config: HomeServerConfig,
reactor: Optional[ISynapseReactor] = None,
reactor: ISynapseReactor | None = None,
) -> GenericWorkerServer:
"""
Create a homeserver instance for the Synapse worker process.

View file

@ -22,7 +22,7 @@
import logging
import os
import sys
from typing import Iterable, Optional
from typing import Iterable
from twisted.internet.tcp import Port
from twisted.web.resource import EncodingResourceWrapper, Resource
@ -350,7 +350,7 @@ def load_or_generate_config(argv_options: list[str]) -> HomeServerConfig:
def create_homeserver(
config: HomeServerConfig,
reactor: Optional[ISynapseReactor] = None,
reactor: ISynapseReactor | None = None,
) -> SynapseHomeServer:
"""
Create a homeserver instance for the Synapse main process.

View file

@ -26,7 +26,6 @@ from enum import Enum
from typing import (
TYPE_CHECKING,
Iterable,
Optional,
Pattern,
Sequence,
cast,
@ -95,12 +94,12 @@ class ApplicationService:
token: str,
id: str,
sender: UserID,
url: Optional[str] = None,
namespaces: Optional[JsonDict] = None,
hs_token: Optional[str] = None,
protocols: Optional[Iterable[str]] = None,
url: str | None = None,
namespaces: JsonDict | None = None,
hs_token: str | None = None,
protocols: Iterable[str] | None = None,
rate_limited: bool = True,
ip_range_whitelist: Optional[IPSet] = None,
ip_range_whitelist: IPSet | None = None,
supports_ephemeral: bool = False,
msc3202_transaction_extensions: bool = False,
msc4190_device_management: bool = False,
@ -142,7 +141,7 @@ class ApplicationService:
self.rate_limited = rate_limited
def _check_namespaces(
self, namespaces: Optional[JsonDict]
self, namespaces: JsonDict | None
) -> dict[str, list[Namespace]]:
# Sanity check that it is of the form:
# {
@ -179,9 +178,7 @@ class ApplicationService:
return result
def _matches_regex(
self, namespace_key: str, test_string: str
) -> Optional[Namespace]:
def _matches_regex(self, namespace_key: str, test_string: str) -> Namespace | None:
for namespace in self.namespaces[namespace_key]:
if namespace.regex.match(test_string):
return namespace

View file

@ -25,10 +25,8 @@ from typing import (
TYPE_CHECKING,
Iterable,
Mapping,
Optional,
Sequence,
TypeVar,
Union,
)
from prometheus_client import Counter
@ -222,7 +220,7 @@ class ApplicationServiceApi(SimpleHttpClient):
assert service.hs_token is not None
try:
args: Mapping[bytes, Union[list[bytes], str]] = fields
args: Mapping[bytes, list[bytes] | str] = fields
if self.config.use_appservice_legacy_authorization:
args = {
**fields,
@ -258,11 +256,11 @@ class ApplicationServiceApi(SimpleHttpClient):
async def get_3pe_protocol(
self, service: "ApplicationService", protocol: str
) -> Optional[JsonDict]:
) -> JsonDict | None:
if service.url is None:
return {}
async def _get() -> Optional[JsonDict]:
async def _get() -> JsonDict | None:
# This is required by the configuration.
assert service.hs_token is not None
try:
@ -300,7 +298,7 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get)
async def ping(self, service: "ApplicationService", txn_id: Optional[str]) -> None:
async def ping(self, service: "ApplicationService", txn_id: str | None) -> None:
# The caller should check that url is set
assert service.url is not None, "ping called without URL being set"
@ -322,7 +320,7 @@ class ApplicationServiceApi(SimpleHttpClient):
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
txn_id: Optional[int] = None,
txn_id: int | None = None,
) -> bool:
"""
Push data to an application service.

View file

@ -62,7 +62,6 @@ from typing import (
Callable,
Collection,
Iterable,
Optional,
Sequence,
)
@ -123,10 +122,10 @@ class ApplicationServiceScheduler:
def enqueue_for_appservice(
self,
appservice: ApplicationService,
events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonMapping]] = None,
to_device_messages: Optional[Collection[JsonMapping]] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
events: Collection[EventBase] | None = None,
ephemeral: Collection[JsonMapping] | None = None,
to_device_messages: Collection[JsonMapping] | None = None,
device_list_summary: DeviceListUpdates | None = None,
) -> None:
"""
Enqueue some data to be sent off to an application service.
@ -260,8 +259,8 @@ class _ServiceQueuer:
):
return
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None
one_time_keys_count: TransactionOneTimeKeysCount | None = None
unused_fallback_keys: TransactionUnusedFallbackKeys | None = None
if (
self._msc3202_transaction_extensions_enabled
@ -369,11 +368,11 @@ class _TransactionController:
self,
service: ApplicationService,
events: Sequence[EventBase],
ephemeral: Optional[list[JsonMapping]] = None,
to_device_messages: Optional[list[JsonMapping]] = None,
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
ephemeral: list[JsonMapping] | None = None,
to_device_messages: list[JsonMapping] | None = None,
one_time_keys_count: TransactionOneTimeKeysCount | None = None,
unused_fallback_keys: TransactionUnusedFallbackKeys | None = None,
device_list_summary: DeviceListUpdates | None = None,
) -> None:
"""
Create a transaction with the given data and send to the provided
@ -504,7 +503,7 @@ class _Recoverer:
self.service = service
self.callback = callback
self.backoff_counter = 1
self.scheduled_recovery: Optional[IDelayedCall] = None
self.scheduled_recovery: IDelayedCall | None = None
def recover(self) -> None:
delay = 2**self.backoff_counter

View file

@ -36,9 +36,7 @@ from typing import (
Iterable,
Iterator,
MutableMapping,
Optional,
TypeVar,
Union,
)
import attr
@ -60,7 +58,7 @@ class ConfigError(Exception):
the problem lies.
"""
def __init__(self, msg: str, path: Optional[StrSequence] = None):
def __init__(self, msg: str, path: StrSequence | None = None):
self.msg = msg
self.path = path
@ -175,7 +173,7 @@ class Config:
)
@staticmethod
def parse_size(value: Union[str, int]) -> int:
def parse_size(value: str | int) -> int:
"""Interpret `value` as a number of bytes.
If an integer is provided it is treated as bytes and is unchanged.
@ -202,7 +200,7 @@ class Config:
raise TypeError(f"Bad byte size {value!r}")
@staticmethod
def parse_duration(value: Union[str, int]) -> int:
def parse_duration(value: str | int) -> int:
"""Convert a duration as a string or integer to a number of milliseconds.
If an integer is provided it is treated as milliseconds and is unchanged.
@ -270,7 +268,7 @@ class Config:
return path_exists(file_path)
@classmethod
def check_file(cls, file_path: Optional[str], config_name: str) -> str:
def check_file(cls, file_path: str | None, config_name: str) -> str:
if file_path is None:
raise ConfigError("Missing config for %s." % (config_name,))
try:
@ -318,7 +316,7 @@ class Config:
def read_templates(
self,
filenames: list[str],
custom_template_directories: Optional[Iterable[str]] = None,
custom_template_directories: Iterable[str] | None = None,
) -> list[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
@ -465,11 +463,11 @@ class RootConfig:
data_dir_path: str,
server_name: str,
generate_secrets: bool = False,
report_stats: Optional[bool] = None,
report_stats: bool | None = None,
open_private_ports: bool = False,
listeners: Optional[list[dict]] = None,
tls_certificate_path: Optional[str] = None,
tls_private_key_path: Optional[str] = None,
listeners: list[dict] | None = None,
tls_certificate_path: str | None = None,
tls_private_key_path: str | None = None,
) -> str:
"""
Build a default configuration file
@ -655,7 +653,7 @@ class RootConfig:
@classmethod
def load_or_generate_config(
cls: type[TRootConfig], description: str, argv_options: list[str]
) -> Optional[TRootConfig]:
) -> TRootConfig | None:
"""Parse the commandline and config files
Supports generation of config files, so is used for the main homeserver app.
@ -898,7 +896,7 @@ class RootConfig:
:returns: the previous config object, which no longer has a reference to this
RootConfig.
"""
existing_config: Optional[Config] = getattr(self, section_name, None)
existing_config: Config | None = getattr(self, section_name, None)
if existing_config is None:
raise ValueError(f"Unknown config section '{section_name}'")
logger.info("Reloading config section '%s'", section_name)

View file

@ -6,9 +6,7 @@ from typing import (
Iterator,
Literal,
MutableMapping,
Optional,
TypeVar,
Union,
overload,
)
@ -64,7 +62,7 @@ from synapse.config import ( # noqa: F401
from synapse.types import StrSequence
class ConfigError(Exception):
def __init__(self, msg: str, path: Optional[StrSequence] = None):
def __init__(self, msg: str, path: StrSequence | None = None):
self.msg = msg
self.path = path
@ -146,16 +144,16 @@ class RootConfig:
data_dir_path: str,
server_name: str,
generate_secrets: bool = ...,
report_stats: Optional[bool] = ...,
report_stats: bool | None = ...,
open_private_ports: bool = ...,
listeners: Optional[Any] = ...,
tls_certificate_path: Optional[str] = ...,
tls_private_key_path: Optional[str] = ...,
listeners: Any | None = ...,
tls_certificate_path: str | None = ...,
tls_private_key_path: str | None = ...,
) -> str: ...
@classmethod
def load_or_generate_config(
cls: type[TRootConfig], description: str, argv_options: list[str]
) -> Optional[TRootConfig]: ...
) -> TRootConfig | None: ...
@classmethod
def load_config(
cls: type[TRootConfig], description: str, argv_options: list[str]
@ -183,11 +181,11 @@ class Config:
default_template_dir: str
def __init__(self, root_config: RootConfig = ...) -> None: ...
@staticmethod
def parse_size(value: Union[str, int]) -> int: ...
def parse_size(value: str | int) -> int: ...
@staticmethod
def parse_duration(value: Union[str, int]) -> int: ...
def parse_duration(value: str | int) -> int: ...
@staticmethod
def abspath(file_path: Optional[str]) -> str: ...
def abspath(file_path: str | None) -> str: ...
@classmethod
def path_exists(cls, file_path: str) -> bool: ...
@classmethod
@ -200,7 +198,7 @@ class Config:
def read_templates(
self,
filenames: list[str],
custom_template_directories: Optional[Iterable[str]] = None,
custom_template_directories: Iterable[str] | None = None,
) -> list[jinja2.Template]: ...
def read_config_files(config_files: Iterable[str]) -> dict[str, Any]: ...

View file

@ -20,7 +20,7 @@
#
import logging
from typing import Any, Iterable, Optional
from typing import Any, Iterable
from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
@ -46,7 +46,7 @@ class ApiConfig(Config):
def _get_prejoin_state_entries(
self, config: JsonDict
) -> Iterable[tuple[str, Optional[str]]]:
) -> Iterable[tuple[str, str | None]]:
"""Get the event types and state keys to include in the prejoin state."""
room_prejoin_state_config = config.get("room_prejoin_state") or {}

View file

@ -23,7 +23,7 @@ import logging
import os
import re
import threading
from typing import Any, Callable, Mapping, Optional
from typing import Any, Callable, Mapping
import attr
@ -53,7 +53,7 @@ class CacheProperties:
default_factor_size: float = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
resize_all_caches_func: Optional[Callable[[], None]] = None
resize_all_caches_func: Callable[[], None] | None = None
properties = CacheProperties()
@ -107,7 +107,7 @@ class CacheConfig(Config):
cache_factors: dict[str, float]
global_factor: float
track_memory_usage: bool
expiry_time_msec: Optional[int]
expiry_time_msec: int | None
sync_response_cache_duration: int
@staticmethod

View file

@ -20,7 +20,7 @@
#
#
from typing import Any, Optional
from typing import Any
from synapse.config.sso import SsoAttributeRequirement
from synapse.types import JsonDict
@ -49,7 +49,7 @@ class CasConfig(Config):
# TODO Update this to a _synapse URL.
public_baseurl = self.root.server.public_baseurl
self.cas_service_url: Optional[str] = (
self.cas_service_url: str | None = (
public_baseurl + "_matrix/client/r0/login/cas/ticket"
)

View file

@ -19,7 +19,7 @@
#
from os import path
from typing import Any, Optional
from typing import Any
from synapse.config import ConfigError
from synapse.types import JsonDict
@ -33,11 +33,11 @@ class ConsentConfig(Config):
def __init__(self, *args: Any):
super().__init__(*args)
self.user_consent_version: Optional[str] = None
self.user_consent_template_dir: Optional[str] = None
self.user_consent_server_notice_content: Optional[JsonDict] = None
self.user_consent_version: str | None = None
self.user_consent_template_dir: str | None = None
self.user_consent_server_notice_content: JsonDict | None = None
self.user_consent_server_notice_to_guests = False
self.block_events_without_consent_error: Optional[str] = None
self.block_events_without_consent_error: str | None = None
self.user_consent_at_registration = False
self.user_consent_policy_name = "Privacy Policy"

View file

@ -59,7 +59,7 @@ class ClientAuthMethod(enum.Enum):
PRIVATE_KEY_JWT = "private_key_jwt"
def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]:
def _parse_jwks(jwks: JsonDict | None) -> Optional["JsonWebKey"]:
"""A helper function to parse a JWK dict into a JsonWebKey."""
if jwks is None:
@ -71,7 +71,7 @@ def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]:
def _check_client_secret(
instance: "MSC3861", _attribute: attr.Attribute, _value: Optional[str]
instance: "MSC3861", _attribute: attr.Attribute, _value: str | None
) -> None:
if instance._client_secret and instance._client_secret_path:
raise ConfigError(
@ -88,7 +88,7 @@ def _check_client_secret(
def _check_admin_token(
instance: "MSC3861", _attribute: attr.Attribute, _value: Optional[str]
instance: "MSC3861", _attribute: attr.Attribute, _value: str | None
) -> None:
if instance._admin_token and instance._admin_token_path:
raise ConfigError(
@ -124,7 +124,7 @@ class MSC3861:
issuer: str = attr.ib(default="", validator=attr.validators.instance_of(str))
"""The URL of the OIDC Provider."""
issuer_metadata: Optional[JsonDict] = attr.ib(default=None)
issuer_metadata: JsonDict | None = attr.ib(default=None)
"""The issuer metadata to use, otherwise discovered from /.well-known/openid-configuration as per MSC2965."""
client_id: str = attr.ib(
@ -138,7 +138,7 @@ class MSC3861:
)
"""The auth method used when calling the introspection endpoint."""
_client_secret: Optional[str] = attr.ib(
_client_secret: str | None = attr.ib(
default=None,
validator=[
attr.validators.optional(attr.validators.instance_of(str)),
@ -150,7 +150,7 @@ class MSC3861:
when using any of the client_secret_* client auth methods.
"""
_client_secret_path: Optional[str] = attr.ib(
_client_secret_path: str | None = attr.ib(
default=None,
validator=[
attr.validators.optional(attr.validators.instance_of(str)),
@ -196,19 +196,19 @@ class MSC3861:
("experimental", "msc3861", "client_auth_method"),
)
introspection_endpoint: Optional[str] = attr.ib(
introspection_endpoint: str | None = attr.ib(
default=None,
validator=attr.validators.optional(attr.validators.instance_of(str)),
)
"""The URL of the introspection endpoint used to validate access tokens."""
account_management_url: Optional[str] = attr.ib(
account_management_url: str | None = attr.ib(
default=None,
validator=attr.validators.optional(attr.validators.instance_of(str)),
)
"""The URL of the My Account page on the OIDC Provider as per MSC2965."""
_admin_token: Optional[str] = attr.ib(
_admin_token: str | None = attr.ib(
default=None,
validator=[
attr.validators.optional(attr.validators.instance_of(str)),
@ -220,7 +220,7 @@ class MSC3861:
This is used by the OIDC provider, to make admin calls to Synapse.
"""
_admin_token_path: Optional[str] = attr.ib(
_admin_token_path: str | None = attr.ib(
default=None,
validator=[
attr.validators.optional(attr.validators.instance_of(str)),
@ -232,7 +232,7 @@ class MSC3861:
external file.
"""
def client_secret(self) -> Optional[str]:
def client_secret(self) -> str | None:
"""Returns the secret given via `client_secret` or `client_secret_path`."""
if self._client_secret_path:
return read_secret_from_file_once(
@ -241,7 +241,7 @@ class MSC3861:
)
return self._client_secret
def admin_token(self) -> Optional[str]:
def admin_token(self) -> str | None:
"""Returns the admin token given via `admin_token` or `admin_token_path`."""
if self._admin_token_path:
return read_secret_from_file_once(
@ -526,7 +526,7 @@ class ExperimentalConfig(Config):
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
self.msc4108_enabled = experimental.get("msc4108_enabled", False)
self.msc4108_delegation_endpoint: Optional[str] = experimental.get(
self.msc4108_delegation_endpoint: str | None = experimental.get(
"msc4108_delegation_endpoint", None
)

View file

@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import Any, Optional
from typing import Any
from synapse.config._base import Config
from synapse.config._util import validate_config
@ -32,7 +32,7 @@ class FederationConfig(Config):
federation_config = config.setdefault("federation", {})
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist: Optional[dict] = None
self.federation_domain_whitelist: dict | None = None
federation_domain_whitelist = config.get("federation_domain_whitelist", None)
if federation_domain_whitelist is not None:

View file

@ -23,7 +23,7 @@
import hashlib
import logging
import os
from typing import TYPE_CHECKING, Any, Iterator, Optional
from typing import TYPE_CHECKING, Any, Iterator
import attr
import jsonschema
@ -110,7 +110,7 @@ class TrustedKeyServer:
server_name: str
# map from key id to key object, or None to disable signature verification.
verify_keys: Optional[dict[str, VerifyKey]] = None
verify_keys: dict[str, VerifyKey] | None = None
class KeyConfig(Config):
@ -219,7 +219,7 @@ class KeyConfig(Config):
if form_secret_path:
if form_secret:
raise ConfigError(CONFLICTING_FORM_SECRET_OPTS_ERROR)
self.form_secret: Optional[str] = read_file(
self.form_secret: str | None = read_file(
form_secret_path, ("form_secret_path",)
).strip()
else:
@ -279,7 +279,7 @@ class KeyConfig(Config):
raise ConfigError("Error reading %s: %s" % (name, str(e)))
def read_old_signing_keys(
self, old_signing_keys: Optional[JsonDict]
self, old_signing_keys: JsonDict | None
) -> dict[str, "VerifyKeyWithExpiry"]:
if old_signing_keys is None:
return {}
@ -408,7 +408,7 @@ def _parse_key_servers(
server_name = server["server_name"]
result = TrustedKeyServer(server_name=server_name)
verify_keys: Optional[dict[str, str]] = server.get("verify_keys")
verify_keys: dict[str, str] | None = server.get("verify_keys")
if verify_keys is not None:
result.verify_keys = {}
for key_id, key_base64 in verify_keys.items():

View file

@ -26,7 +26,7 @@ import os
import sys
import threading
from string import Template
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import yaml
from zope.interface import implementer
@ -280,7 +280,7 @@ def one_time_logging_setup(*, logBeginner: LogBeginner = globalLogBeginner) -> N
def _setup_stdlib_logging(
config: "HomeServerConfig", log_config_path: Optional[str]
config: "HomeServerConfig", log_config_path: str | None
) -> None:
"""
Set up Python standard library logging.
@ -327,7 +327,7 @@ def _load_logging_config(log_config_path: str) -> None:
reset_logging_config()
def _reload_logging_config(log_config_path: Optional[str]) -> None:
def _reload_logging_config(log_config_path: str | None) -> None:
"""
Reload the log configuration from the file and apply it.
"""

View file

@ -13,7 +13,7 @@
#
#
from typing import Any, Optional
from typing import Any
from pydantic import (
AnyHttpUrl,
@ -36,8 +36,8 @@ from ._base import Config, ConfigError, RootConfig
class MasConfigModel(ParseModel):
enabled: StrictBool = False
endpoint: AnyHttpUrl = AnyHttpUrl("http://localhost:8080")
secret: Optional[StrictStr] = Field(default=None)
secret_path: Optional[FilePath] = Field(default=None)
secret: StrictStr | None = Field(default=None)
secret_path: FilePath | None = Field(default=None)
@model_validator(mode="after")
def verify_secret(self) -> Self:

View file

@ -15,7 +15,7 @@
#
#
from typing import Any, Optional
from typing import Any
from pydantic import Field, StrictStr, ValidationError, model_validator
from typing_extensions import Self
@ -29,7 +29,7 @@ from ._base import Config, ConfigError
class TransportConfigModel(ParseModel):
type: StrictStr
livekit_service_url: Optional[StrictStr] = Field(default=None)
livekit_service_url: StrictStr | None = Field(default=None)
"""An optional livekit service URL. Only required if type is "livekit"."""
@model_validator(mode="after")

View file

@ -20,7 +20,7 @@
#
#
from typing import Any, Optional
from typing import Any
import attr
@ -75,7 +75,7 @@ class MetricsConfig(Config):
)
def generate_config_section(
self, report_stats: Optional[bool] = None, **kwargs: Any
self, report_stats: bool | None = None, **kwargs: Any
) -> str:
if report_stats is not None:
res = "report_stats: %s\n" % ("true" if report_stats else "false")

View file

@ -21,7 +21,7 @@
import importlib.resources as importlib_resources
import json
import re
from typing import Any, Iterable, Optional, Pattern
from typing import Any, Iterable, Pattern
from urllib import parse as urlparse
import attr
@ -39,7 +39,7 @@ class OEmbedEndpointConfig:
# The patterns to match.
url_patterns: list[Pattern[str]]
# The supported formats.
formats: Optional[list[str]]
formats: list[str] | None
class OembedConfig(Config):

View file

@ -21,7 +21,7 @@
#
from collections import Counter
from typing import Any, Collection, Iterable, Mapping, Optional
from typing import Any, Collection, Iterable, Mapping
import attr
@ -276,7 +276,7 @@ def _parse_oidc_config_dict(
) from e
client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None
client_secret_jwt_key: OidcProviderClientSecretJwtKey | None = None
if client_secret_jwt_key_config is not None:
keyfile = client_secret_jwt_key_config.get("key_file")
if keyfile:
@ -384,10 +384,10 @@ class OidcProviderConfig:
idp_name: str
# Optional MXC URI for icon for this IdP.
idp_icon: Optional[str]
idp_icon: str | None
# Optional brand identifier for this IdP.
idp_brand: Optional[str]
idp_brand: str | None
# whether the OIDC discovery mechanism is used to discover endpoints
discover: bool
@ -401,11 +401,11 @@ class OidcProviderConfig:
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
# a secret.
client_secret: Optional[str]
client_secret: str | None
# key to use to construct a JWT to use as a client secret. May be `None` if
# `client_secret` is set.
client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey]
client_secret_jwt_key: OidcProviderClientSecretJwtKey | None
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
@ -416,7 +416,7 @@ class OidcProviderConfig:
# Valid values are 'auto', 'always', and 'never'.
pkce_method: str
id_token_signing_alg_values_supported: Optional[list[str]]
id_token_signing_alg_values_supported: list[str] | None
"""
List of the JWS signing algorithms (`alg` values) that are supported for signing the
`id_token`.
@ -448,18 +448,18 @@ class OidcProviderConfig:
scopes: Collection[str]
# the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint: Optional[str]
authorization_endpoint: str | None
# the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint: Optional[str]
token_endpoint: str | None
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
userinfo_endpoint: Optional[str]
userinfo_endpoint: str | None
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
jwks_uri: Optional[str]
jwks_uri: str | None
# Whether Synapse should react to backchannel logouts
backchannel_logout_enabled: bool
@ -474,7 +474,7 @@ class OidcProviderConfig:
# values are: "auto" or "userinfo_endpoint".
user_profile_method: str
redirect_uri: Optional[str]
redirect_uri: str | None
"""
An optional replacement for Synapse's hardcoded `redirect_uri` URL
(`<public_baseurl>/_synapse/client/oidc/callback`). This can be used to send

View file

@ -19,7 +19,7 @@
#
#
from typing import Any, Optional, cast
from typing import Any, cast
import attr
@ -39,7 +39,7 @@ class RatelimitSettings:
cls,
config: dict[str, Any],
key: str,
defaults: Optional[dict[str, float]] = None,
defaults: dict[str, float] | None = None,
) -> "RatelimitSettings":
"""Parse config[key] as a new-style rate limiter config.

View file

@ -20,7 +20,7 @@
#
#
import argparse
from typing import Any, Optional
from typing import Any
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError, read_file
@ -181,7 +181,7 @@ class RegistrationConfig(Config):
refreshable_access_token_lifetime = self.parse_duration(
refreshable_access_token_lifetime
)
self.refreshable_access_token_lifetime: Optional[int] = (
self.refreshable_access_token_lifetime: int | None = (
refreshable_access_token_lifetime
)
@ -226,7 +226,7 @@ class RegistrationConfig(Config):
refresh_token_lifetime = config.get("refresh_token_lifetime")
if refresh_token_lifetime is not None:
refresh_token_lifetime = self.parse_duration(refresh_token_lifetime)
self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime
self.refresh_token_lifetime: int | None = refresh_token_lifetime
if (
self.session_lifetime is not None

View file

@ -20,7 +20,7 @@
#
import logging
from typing import Any, Optional
from typing import Any
import attr
@ -35,8 +35,8 @@ class RetentionPurgeJob:
"""Object describing the configuration of the manhole"""
interval: int
shortest_max_lifetime: Optional[int]
longest_max_lifetime: Optional[int]
shortest_max_lifetime: int | None
longest_max_lifetime: int | None
class RetentionConfig(Config):

View file

@ -25,7 +25,7 @@ import logging
import os.path
import urllib.parse
from textwrap import indent
from typing import Any, Iterable, Optional, TypedDict, Union
from typing import Any, Iterable, TypedDict
from urllib.request import getproxies_environment
import attr
@ -95,9 +95,9 @@ def _6to4(network: IPNetwork) -> IPNetwork:
def generate_ip_set(
ip_addresses: Optional[Iterable[str]],
extra_addresses: Optional[Iterable[str]] = None,
config_path: Optional[StrSequence] = None,
ip_addresses: Iterable[str] | None,
extra_addresses: Iterable[str] | None = None,
config_path: StrSequence | None = None,
) -> IPSet:
"""
Generate an IPSet from a list of IP addresses or CIDRs.
@ -230,8 +230,8 @@ class HttpListenerConfig:
x_forwarded: bool = False
resources: list[HttpResourceConfig] = attr.Factory(list)
additional_resources: dict[str, dict] = attr.Factory(dict)
tag: Optional[str] = None
request_id_header: Optional[str] = None
tag: str | None = None
request_id_header: str | None = None
@attr.s(slots=True, frozen=True, auto_attribs=True)
@ -244,7 +244,7 @@ class TCPListenerConfig:
tls: bool = False
# http_options is only populated if type=http
http_options: Optional[HttpListenerConfig] = None
http_options: HttpListenerConfig | None = None
def get_site_tag(self) -> str:
"""Retrieves http_options.tag if it exists, otherwise the port number."""
@ -269,7 +269,7 @@ class UnixListenerConfig:
type: str = attr.ib(validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
# http_options is only populated if type=http
http_options: Optional[HttpListenerConfig] = None
http_options: HttpListenerConfig | None = None
def get_site_tag(self) -> str:
return "unix"
@ -279,7 +279,7 @@ class UnixListenerConfig:
return False
ListenerConfig = Union[TCPListenerConfig, UnixListenerConfig]
ListenerConfig = TCPListenerConfig | UnixListenerConfig
@attr.s(slots=True, frozen=True, auto_attribs=True)
@ -288,14 +288,14 @@ class ManholeConfig:
username: str = attr.ib(validator=attr.validators.instance_of(str))
password: str = attr.ib(validator=attr.validators.instance_of(str))
priv_key: Optional[Key]
pub_key: Optional[Key]
priv_key: Key | None
pub_key: Key | None
@attr.s(frozen=True)
class LimitRemoteRoomsConfig:
enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
complexity: Union[float, int] = attr.ib(
complexity: float | int = attr.ib(
validator=attr.validators.instance_of((float, int)), # noqa
default=1.0,
)
@ -313,11 +313,11 @@ class ProxyConfigDictionary(TypedDict):
Dictionary of proxy settings suitable for interacting with `urllib.request` API's
"""
http: Optional[str]
http: str | None
"""
Proxy server to use for HTTP requests.
"""
https: Optional[str]
https: str | None
"""
Proxy server to use for HTTPS requests.
"""
@ -336,15 +336,15 @@ class ProxyConfig:
Synapse configuration for HTTP proxy settings.
"""
http_proxy: Optional[str]
http_proxy: str | None
"""
Proxy server to use for HTTP requests.
"""
https_proxy: Optional[str]
https_proxy: str | None
"""
Proxy server to use for HTTPS requests.
"""
no_proxy_hosts: Optional[list[str]]
no_proxy_hosts: list[str] | None
"""
List of hosts, IP addresses, or IP ranges in CIDR format which should not use the
proxy. Synapse will directly connect to these hosts.
@ -607,7 +607,7 @@ class ServerConfig(Config):
# before redacting them.
redaction_retention_period = config.get("redaction_retention_period", "7d")
if redaction_retention_period is not None:
self.redaction_retention_period: Optional[int] = self.parse_duration(
self.redaction_retention_period: int | None = self.parse_duration(
redaction_retention_period
)
else:
@ -618,7 +618,7 @@ class ServerConfig(Config):
"forgotten_room_retention_period", None
)
if forgotten_room_retention_period is not None:
self.forgotten_room_retention_period: Optional[int] = self.parse_duration(
self.forgotten_room_retention_period: int | None = self.parse_duration(
forgotten_room_retention_period
)
else:
@ -627,7 +627,7 @@ class ServerConfig(Config):
# How long to keep entries in the `users_ips` table.
user_ips_max_age = config.get("user_ips_max_age", "28d")
if user_ips_max_age is not None:
self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age)
self.user_ips_max_age: int | None = self.parse_duration(user_ips_max_age)
else:
self.user_ips_max_age = None
@ -864,11 +864,11 @@ class ServerConfig(Config):
)
# Whitelist of domain names that given next_link parameters must have
next_link_domain_whitelist: Optional[list[str]] = config.get(
next_link_domain_whitelist: list[str] | None = config.get(
"next_link_domain_whitelist"
)
self.next_link_domain_whitelist: Optional[set[str]] = None
self.next_link_domain_whitelist: set[str] | None = None
if next_link_domain_whitelist is not None:
if not isinstance(next_link_domain_whitelist, list):
raise ConfigError("'next_link_domain_whitelist' must be a list")
@ -880,7 +880,7 @@ class ServerConfig(Config):
if not isinstance(templates_config, dict):
raise ConfigError("The 'templates' section must be a dictionary")
self.custom_template_directory: Optional[str] = templates_config.get(
self.custom_template_directory: str | None = templates_config.get(
"custom_template_directory"
)
if self.custom_template_directory is not None and not isinstance(
@ -896,12 +896,12 @@ class ServerConfig(Config):
config.get("exclude_rooms_from_sync") or []
)
delete_stale_devices_after: Optional[str] = (
delete_stale_devices_after: str | None = (
config.get("delete_stale_devices_after") or None
)
if delete_stale_devices_after is not None:
self.delete_stale_devices_after: Optional[int] = self.parse_duration(
self.delete_stale_devices_after: int | None = self.parse_duration(
delete_stale_devices_after
)
else:
@ -910,7 +910,7 @@ class ServerConfig(Config):
# The maximum allowed delay duration for delayed events (MSC4140).
max_event_delay_duration = config.get("max_event_delay_duration")
if max_event_delay_duration is not None:
self.max_event_delay_ms: Optional[int] = self.parse_duration(
self.max_event_delay_ms: int | None = self.parse_duration(
max_event_delay_duration
)
if self.max_event_delay_ms <= 0:
@ -927,7 +927,7 @@ class ServerConfig(Config):
data_dir_path: str,
server_name: str,
open_private_ports: bool,
listeners: Optional[list[dict]],
listeners: list[dict] | None,
**kwargs: Any,
) -> str:
_, bind_port = parse_and_validate_server_name(server_name)
@ -1028,7 +1028,7 @@ class ServerConfig(Config):
help="Turn on the twisted telnet manhole service on the given port.",
)
def read_gc_intervals(self, durations: Any) -> Optional[tuple[float, float, float]]:
def read_gc_intervals(self, durations: Any) -> tuple[float, float, float] | None:
"""Reads the three durations for the GC min interval option, returning seconds."""
if durations is None:
return None
@ -1066,8 +1066,8 @@ def is_threepid_reserved(
def read_gc_thresholds(
thresholds: Optional[list[Any]],
) -> Optional[tuple[int, int, int]]:
thresholds: list[Any] | None,
) -> tuple[int, int, int] | None:
"""Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied.
"""

View file

@ -18,7 +18,7 @@
#
#
from typing import Any, Optional
from typing import Any
from synapse.types import JsonDict, UserID
@ -58,12 +58,12 @@ class ServerNoticesConfig(Config):
def __init__(self, *args: Any):
super().__init__(*args)
self.server_notices_mxid: Optional[str] = None
self.server_notices_mxid_display_name: Optional[str] = None
self.server_notices_mxid_avatar_url: Optional[str] = None
self.server_notices_room_name: Optional[str] = None
self.server_notices_room_avatar_url: Optional[str] = None
self.server_notices_room_topic: Optional[str] = None
self.server_notices_mxid: str | None = None
self.server_notices_mxid_display_name: str | None = None
self.server_notices_mxid_avatar_url: str | None = None
self.server_notices_room_name: str | None = None
self.server_notices_room_avatar_url: str | None = None
self.server_notices_room_topic: str | None = None
self.server_notices_auto_join: bool = False
def read_config(self, config: JsonDict, **kwargs: Any) -> None:

View file

@ -19,7 +19,7 @@
#
#
import logging
from typing import Any, Optional
from typing import Any
import attr
@ -44,8 +44,8 @@ class SsoAttributeRequirement:
attribute: str
# If neither `value` nor `one_of` is given, the attribute must simply exist.
value: Optional[str] = None
one_of: Optional[list[str]] = None
value: str | None = None
one_of: list[str] | None = None
JSON_SCHEMA = {
"type": "object",

View file

@ -20,7 +20,7 @@
#
import logging
from typing import Any, Optional, Pattern
from typing import Any, Pattern
from matrix_common.regex import glob_to_regex
@ -135,8 +135,8 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
self.tls_certificate: Optional[crypto.X509] = None
self.tls_private_key: Optional[crypto.PKey] = None
self.tls_certificate: crypto.X509 | None = None
self.tls_private_key: crypto.PKey | None = None
def read_certificate_from_disk(self) -> None:
"""
@ -147,8 +147,8 @@ class TlsConfig(Config):
def generate_config_section(
self,
tls_certificate_path: Optional[str],
tls_private_key_path: Optional[str],
tls_certificate_path: str | None,
tls_private_key_path: str | None,
**kwargs: Any,
) -> str:
"""If the TLS paths are not specified the default will be certs in the

View file

@ -12,7 +12,7 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from typing import Any, Optional
from typing import Any
from synapse.api.constants import UserTypes
from synapse.types import JsonDict
@ -26,9 +26,7 @@ class UserTypesConfig(Config):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
user_types: JsonDict = config.get("user_types", {})
self.default_user_type: Optional[str] = user_types.get(
"default_user_type", None
)
self.default_user_type: str | None = user_types.get("default_user_type", None)
self.extra_user_types: list[str] = user_types.get("extra_user_types", [])
all_user_types: list[str] = []

View file

@ -22,7 +22,7 @@
import argparse
import logging
from typing import Any, Optional, Union
from typing import Any
import attr
from pydantic import (
@ -79,7 +79,7 @@ MAIN_PROCESS_INSTANCE_MAP_NAME = "main"
logger = logging.getLogger(__name__)
def _instance_to_list_converter(obj: Union[str, list[str]]) -> list[str]:
def _instance_to_list_converter(obj: str | list[str]) -> list[str]:
"""Helper for allowing parsing a string or list of strings to a config
option expecting a list of strings.
"""
@ -119,7 +119,7 @@ class InstanceUnixLocationConfig(ParseModel):
return f"{self.path}"
InstanceLocationConfig = Union[InstanceTcpLocationConfig, InstanceUnixLocationConfig]
InstanceLocationConfig = InstanceTcpLocationConfig | InstanceUnixLocationConfig
@attr.s
@ -190,7 +190,7 @@ class OutboundFederationRestrictedTo:
locations: list of instance locations to connect to proxy via.
"""
instances: Optional[list[str]]
instances: list[str] | None
locations: list[InstanceLocationConfig] = attr.Factory(list)
def __contains__(self, instance: str) -> bool:
@ -246,7 +246,7 @@ class WorkerConfig(Config):
if worker_replication_secret_path:
if worker_replication_secret:
raise ConfigError(CONFLICTING_WORKER_REPLICATION_SECRET_OPTS_ERROR)
self.worker_replication_secret: Optional[str] = read_file(
self.worker_replication_secret: str | None = read_file(
worker_replication_secret_path, ("worker_replication_secret_path",)
).strip()
else:
@ -341,7 +341,7 @@ class WorkerConfig(Config):
% MAIN_PROCESS_INSTANCE_MAP_NAME
)
# type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently
# type-ignore: the expression `A | B` is not a `type[A | B]` currently
self.instance_map: dict[str, InstanceLocationConfig] = (
parse_and_validate_mapping(
instance_map,

View file

@ -21,7 +21,7 @@
import abc
import logging
from typing import TYPE_CHECKING, Callable, Iterable, Optional
from typing import TYPE_CHECKING, Callable, Iterable
import attr
from signedjson.key import (
@ -150,7 +150,7 @@ class Keyring:
"""
def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
self, hs: "HomeServer", key_fetchers: "Iterable[KeyFetcher] | None" = None
):
self.server_name = hs.hostname

View file

@ -160,7 +160,7 @@ def validate_event_for_room_version(event: "EventBase") -> None:
async def check_state_independent_auth_rules(
store: _EventSourceStore,
event: "EventBase",
batched_auth_events: Optional[Mapping[str, "EventBase"]] = None,
batched_auth_events: Mapping[str, "EventBase"] | None = None,
) -> None:
"""Check that an event complies with auth rules that are independent of room state
@ -788,7 +788,7 @@ def _check_joined_room(
def get_send_level(
etype: str, state_key: Optional[str], power_levels_event: Optional["EventBase"]
etype: str, state_key: str | None, power_levels_event: Optional["EventBase"]
) -> int:
"""Get the power level required to send an event of a given type
@ -989,7 +989,7 @@ def _check_power_levels(
user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check: list[tuple[str, Optional[str]]] = [
levels_to_check: list[tuple[str, str | None]] = [
("users_default", None),
("events_default", None),
("state_default", None),
@ -1027,12 +1027,12 @@ def _check_power_levels(
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level: Optional[int] = int(old_loc[level_to_check])
old_level: int | None = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level: Optional[int] = int(new_loc[level_to_check])
new_level: int | None = int(new_loc[level_to_check])
else:
new_level = None

View file

@ -28,7 +28,6 @@ from typing import (
Generic,
Iterable,
Literal,
Optional,
TypeVar,
Union,
overload,
@ -90,21 +89,21 @@ class DictProperty(Generic[T]):
def __get__(
self,
instance: Literal[None],
owner: Optional[type[_DictPropertyInstance]] = None,
owner: type[_DictPropertyInstance] | None = None,
) -> "DictProperty": ...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[type[_DictPropertyInstance]] = None,
owner: type[_DictPropertyInstance] | None = None,
) -> T: ...
def __get__(
self,
instance: Optional[_DictPropertyInstance],
owner: Optional[type[_DictPropertyInstance]] = None,
) -> Union[T, "DictProperty"]:
instance: _DictPropertyInstance | None,
owner: type[_DictPropertyInstance] | None = None,
) -> T | "DictProperty":
# if the property is accessed as a class property rather than an instance
# property, return the property itself rather than the value
if instance is None:
@ -156,21 +155,21 @@ class DefaultDictProperty(DictProperty, Generic[T]):
def __get__(
self,
instance: Literal[None],
owner: Optional[type[_DictPropertyInstance]] = None,
owner: type[_DictPropertyInstance] | None = None,
) -> "DefaultDictProperty": ...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[type[_DictPropertyInstance]] = None,
owner: type[_DictPropertyInstance] | None = None,
) -> T: ...
def __get__(
self,
instance: Optional[_DictPropertyInstance],
owner: Optional[type[_DictPropertyInstance]] = None,
) -> Union[T, "DefaultDictProperty"]:
instance: _DictPropertyInstance | None,
owner: type[_DictPropertyInstance] | None = None,
) -> T | "DefaultDictProperty":
if instance is None:
return self
assert isinstance(instance, EventBase)
@ -191,7 +190,7 @@ class EventBase(metaclass=abc.ABCMeta):
signatures: dict[str, dict[str, str]],
unsigned: JsonDict,
internal_metadata_dict: JsonDict,
rejected_reason: Optional[str],
rejected_reason: str | None,
):
assert room_version.event_format == self.format_version
@ -209,7 +208,7 @@ class EventBase(metaclass=abc.ABCMeta):
hashes: DictProperty[dict[str, str]] = DictProperty("hashes")
origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
sender: DictProperty[str] = DictProperty("sender")
# TODO state_key should be Optional[str]. This is generally asserted in Synapse
# TODO state_key should be str | None. This is generally asserted in Synapse
# by calling is_state() first (which ensures it is not None), but it is hard (not possible?)
# to properly annotate that calling is_state() asserts that state_key exists
# and is non-None. It would be better to replace such direct references with
@ -231,7 +230,7 @@ class EventBase(metaclass=abc.ABCMeta):
return self.content["membership"]
@property
def redacts(self) -> Optional[str]:
def redacts(self) -> str | None:
"""MSC2176 moved the redacts field into the content."""
if self.room_version.updated_redaction_rules:
return self.content.get("redacts")
@ -240,7 +239,7 @@ class EventBase(metaclass=abc.ABCMeta):
def is_state(self) -> bool:
return self.get_state_key() is not None
def get_state_key(self) -> Optional[str]:
def get_state_key(self) -> str | None:
"""Get the state key of this event, or None if it's not a state event"""
return self._dict.get("state_key")
@ -250,13 +249,13 @@ class EventBase(metaclass=abc.ABCMeta):
return d
def get(self, key: str, default: Optional[Any] = None) -> Any:
def get(self, key: str, default: Any | None = None) -> Any:
return self._dict.get(key, default)
def get_internal_metadata_dict(self) -> JsonDict:
return self.internal_metadata.get_dict()
def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict:
def get_pdu_json(self, time_now: int | None = None) -> JsonDict:
pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
@ -283,13 +282,13 @@ class EventBase(metaclass=abc.ABCMeta):
return template_json
def __getitem__(self, field: str) -> Optional[Any]:
def __getitem__(self, field: str) -> Any | None:
return self._dict[field]
def __contains__(self, field: str) -> bool:
return field in self._dict
def items(self) -> list[tuple[str, Optional[Any]]]:
def items(self) -> list[tuple[str, Any | None]]:
return list(self._dict.items())
def keys(self) -> Iterable[str]:
@ -348,8 +347,8 @@ class FrozenEvent(EventBase):
self,
event_dict: JsonDict,
room_version: RoomVersion,
internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
internal_metadata_dict: JsonDict | None = None,
rejected_reason: str | None = None,
):
internal_metadata_dict = internal_metadata_dict or {}
@ -400,8 +399,8 @@ class FrozenEventV2(EventBase):
self,
event_dict: JsonDict,
room_version: RoomVersion,
internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
internal_metadata_dict: JsonDict | None = None,
rejected_reason: str | None = None,
):
internal_metadata_dict = internal_metadata_dict or {}
@ -427,7 +426,7 @@ class FrozenEventV2(EventBase):
else:
frozen_dict = event_dict
self._event_id: Optional[str] = None
self._event_id: str | None = None
super().__init__(
frozen_dict,
@ -502,8 +501,8 @@ class FrozenEventV4(FrozenEventV3):
self,
event_dict: JsonDict,
room_version: RoomVersion,
internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
internal_metadata_dict: JsonDict | None = None,
rejected_reason: str | None = None,
):
super().__init__(
event_dict=event_dict,
@ -511,7 +510,7 @@ class FrozenEventV4(FrozenEventV3):
internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason,
)
self._room_id: Optional[str] = None
self._room_id: str | None = None
@property
def room_id(self) -> str:
@ -554,7 +553,7 @@ class FrozenEventV4(FrozenEventV3):
def _event_type_from_format_version(
format_version: int,
) -> type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3]:
"""Returns the python type to use to construct an Event object for the
given event format version.
@ -580,8 +579,8 @@ def _event_type_from_format_version(
def make_event_from_dict(
event_dict: JsonDict,
room_version: RoomVersion = RoomVersions.V1,
internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
internal_metadata_dict: JsonDict | None = None,
rejected_reason: str | None = None,
) -> EventBase:
"""Construct an EventBase from the given event dict"""
event_type = _event_type_from_format_version(room_version.event_format)
@ -598,10 +597,10 @@ class _EventRelation:
rel_type: str
# The aggregation key. Will be None if the rel_type is not m.annotation or is
# not a string.
aggregation_key: Optional[str]
aggregation_key: str | None
def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
def relation_from_event(event: EventBase) -> _EventRelation | None:
"""
Attempt to parse relation information an event.

View file

@ -19,7 +19,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any
import attr
from signedjson.types import SigningKey
@ -83,7 +83,7 @@ class EventBuilder:
room_version: RoomVersion
# MSC4291 makes the room ID == the create event ID. This means the create event has no room_id.
room_id: Optional[str]
room_id: str | None
type: str
sender: str
@ -92,9 +92,9 @@ class EventBuilder:
# These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist.
_state_key: Optional[str] = None
_redacts: Optional[str] = None
_origin_server_ts: Optional[int] = None
_state_key: str | None = None
_redacts: str | None = None
_origin_server_ts: int | None = None
internal_metadata: EventInternalMetadata = attr.Factory(
lambda: EventInternalMetadata({})
@ -126,8 +126,8 @@ class EventBuilder:
async def build(
self,
prev_event_ids: list[str],
auth_event_ids: Optional[list[str]],
depth: Optional[int] = None,
auth_event_ids: list[str] | None,
depth: int | None = None,
) -> EventBase:
"""Transform into a fully signed and hashed event
@ -205,8 +205,8 @@ class EventBuilder:
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[StrCollection, list[tuple[str, dict[str, str]]]]
auth_events: Union[list[str], list[tuple[str, dict[str, str]]]]
prev_events: StrCollection | list[tuple[str, dict[str, str]]]
auth_events: list[str] | list[tuple[str, dict[str, str]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
prev_events = await self._store.add_event_hashes(prev_event_ids)
@ -327,7 +327,7 @@ def create_local_event_from_event_dict(
signing_key: SigningKey,
room_version: RoomVersion,
event_dict: JsonDict,
internal_metadata_dict: Optional[JsonDict] = None,
internal_metadata_dict: JsonDict | None = None,
) -> EventBase:
"""Takes a fully formed event dict, ensuring that fields like
`origin_server_ts` have correct values for a locally produced event,

View file

@ -25,9 +25,7 @@ from typing import (
Awaitable,
Callable,
Iterable,
Optional,
TypeVar,
Union,
)
from typing_extensions import ParamSpec
@ -44,7 +42,7 @@ GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[dict[str, set[UserPresenceState]]]
]
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[set[str], str]]]
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[set[str] | str]]
logger = logging.getLogger(__name__)
@ -77,8 +75,8 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(
f: Optional[Callable[P, R]],
) -> Optional[Callable[P, Awaitable[R]]]:
f: Callable[P, R] | None,
) -> Callable[P, Awaitable[R]] | None:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
@ -95,7 +93,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
return run
# Register the hooks through the module API.
hooks: dict[str, Optional[Callable[..., Any]]] = {
hooks: dict[str, Callable[..., Any] | None] = {
hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods
}
@ -118,8 +116,8 @@ class PresenceRouter:
def register_presence_router_callbacks(
self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
get_users_for_states: GET_USERS_FOR_STATES_CALLBACK | None = None,
get_interested_users: GET_INTERESTED_USERS_CALLBACK | None = None,
) -> None:
# PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner
@ -191,7 +189,7 @@ class PresenceRouter:
return users_for_states
async def get_interested_users(self, user_id: str) -> Union[set[str], str]:
async def get_interested_users(self, user_id: str) -> set[str] | str:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.

View file

@ -51,7 +51,7 @@ class UnpersistedEventContextBase(ABC):
def __init__(self, storage_controller: "StorageControllers"):
self._storage: "StorageControllers" = storage_controller
self.app_service: Optional[ApplicationService] = None
self.app_service: ApplicationService | None = None
@abstractmethod
async def persist(
@ -134,20 +134,20 @@ class EventContext(UnpersistedEventContextBase):
_storage: "StorageControllers"
state_group_deltas: dict[tuple[int, int], StateMap[str]]
rejected: Optional[str] = None
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
_state_delta_due_to_event: Optional[StateMap[str]] = None
app_service: Optional[ApplicationService] = None
rejected: str | None = None
_state_group: int | None = None
state_group_before_event: int | None = None
_state_delta_due_to_event: StateMap[str] | None = None
app_service: ApplicationService | None = None
partial_state: bool = False
@staticmethod
def with_state(
storage: "StorageControllers",
state_group: Optional[int],
state_group_before_event: Optional[int],
state_delta_due_to_event: Optional[StateMap[str]],
state_group: int | None,
state_group_before_event: int | None,
state_delta_due_to_event: StateMap[str] | None,
partial_state: bool,
state_group_deltas: dict[tuple[int, int], StateMap[str]],
) -> "EventContext":
@ -227,7 +227,7 @@ class EventContext(UnpersistedEventContextBase):
return context
@property
def state_group(self) -> Optional[int]:
def state_group(self) -> int | None:
"""The ID of the state group for this event.
Note that state events are persisted with a state group which includes the new
@ -354,13 +354,13 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
"""
_storage: "StorageControllers"
state_group_before_event: Optional[int]
state_group_after_event: Optional[int]
state_delta_due_to_event: Optional[StateMap[str]]
prev_group_for_state_group_before_event: Optional[int]
delta_ids_to_state_group_before_event: Optional[StateMap[str]]
state_group_before_event: int | None
state_group_after_event: int | None
state_delta_due_to_event: StateMap[str] | None
prev_group_for_state_group_before_event: int | None
delta_ids_to_state_group_before_event: StateMap[str] | None
partial_state: bool
state_map_before_event: Optional[StateMap[str]] = None
state_map_before_event: StateMap[str] | None = None
@classmethod
async def batch_persist_unpersisted_contexts(
@ -511,7 +511,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
def _encode_state_group_delta(
state_group_delta: dict[tuple[int, int], StateMap[str]],
) -> list[tuple[int, int, Optional[list[tuple[str, str, str]]]]]:
) -> list[tuple[int, int, list[tuple[str, str, str]] | None]]:
if not state_group_delta:
return []
@ -538,8 +538,8 @@ def _decode_state_group_delta(
def _encode_state_dict(
state_dict: Optional[StateMap[str]],
) -> Optional[list[tuple[str, str, str]]]:
state_dict: StateMap[str] | None,
) -> list[tuple[str, str, str]] | None:
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
JSON we need to convert them to a form that can.
"""
@ -550,8 +550,8 @@ def _encode_state_dict(
def _decode_state_dict(
input: Optional[list[tuple[str, str, str]]],
) -> Optional[StateMap[str]]:
input: list[tuple[str, str, str]] | None,
) -> StateMap[str] | None:
"""Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None:
return None

View file

@ -30,8 +30,6 @@ from typing import (
Mapping,
Match,
MutableMapping,
Optional,
Union,
)
import attr
@ -415,9 +413,9 @@ class SerializeEventConfig:
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1
# The entity that requested the event. This is used to determine whether to include
# the transaction_id in the unsigned section of the event.
requester: Optional[Requester] = None
requester: Requester | None = None
# List of event fields to include. If empty, all fields will be returned.
only_event_fields: Optional[list[str]] = None
only_event_fields: list[str] | None = None
# Some events can have stripped room state stored in the `unsigned` field.
# This is required for invite and knock functionality. If this option is
# False, that state will be removed from the event before it is returned.
@ -439,7 +437,7 @@ def make_config_for_admin(existing: SerializeEventConfig) -> SerializeEventConfi
def serialize_event(
e: Union[JsonDict, EventBase],
e: JsonDict | EventBase,
time_now_ms: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@ -480,7 +478,7 @@ def serialize_event(
# If we have a txn_id saved in the internal_metadata, we should include it in the
# unsigned section of the event if it was sent by the same session as the one
# requesting the event.
txn_id: Optional[str] = getattr(e.internal_metadata, "txn_id", None)
txn_id: str | None = getattr(e.internal_metadata, "txn_id", None)
if (
txn_id is not None
and config.requester is not None
@ -490,7 +488,7 @@ def serialize_event(
# this includes old events as well as those created by appservice, guests,
# or with tokens minted with the admin API. For those events, fallback
# to using the access token instead.
event_device_id: Optional[str] = getattr(e.internal_metadata, "device_id", None)
event_device_id: str | None = getattr(e.internal_metadata, "device_id", None)
if event_device_id is not None:
if event_device_id == config.requester.device_id:
d["unsigned"]["transaction_id"] = txn_id
@ -504,9 +502,7 @@ def serialize_event(
#
# For guests and appservice users, we can't check the access token ID
# so assume it is the same session.
event_token_id: Optional[int] = getattr(
e.internal_metadata, "token_id", None
)
event_token_id: int | None = getattr(e.internal_metadata, "token_id", None)
if (
(
event_token_id is not None
@ -577,11 +573,11 @@ class EventClientSerializer:
async def serialize_event(
self,
event: Union[JsonDict, EventBase],
event: JsonDict | EventBase,
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
bundle_aggregations: Optional[dict[str, "BundledAggregations"]] = None,
bundle_aggregations: dict[str, "BundledAggregations"] | None = None,
) -> JsonDict:
"""Serializes a single event.
@ -712,11 +708,11 @@ class EventClientSerializer:
@trace
async def serialize_events(
self,
events: Collection[Union[JsonDict, EventBase]],
events: Collection[JsonDict | EventBase],
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
bundle_aggregations: Optional[dict[str, "BundledAggregations"]] = None,
bundle_aggregations: dict[str, "BundledAggregations"] | None = None,
) -> list[JsonDict]:
"""Serializes multiple events.
@ -755,13 +751,13 @@ class EventClientSerializer:
self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback)
_PowerLevel = Union[str, int]
PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
_PowerLevel = str | int
PowerLevelsContent = Mapping[str, _PowerLevel | Mapping[str, _PowerLevel]]
def copy_and_fixup_power_levels_contents(
old_power_levels: PowerLevelsContent,
) -> dict[str, Union[int, dict[str, int]]]:
) -> dict[str, int | dict[str, int]]:
"""Copy the content of a power_levels event, unfreezing immutabledicts along the way.
We accept as input power level values which are strings, provided they represent an
@ -777,7 +773,7 @@ def copy_and_fixup_power_levels_contents(
if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels: dict[str, Union[int, dict[str, int]]] = {}
power_levels: dict[str, int | dict[str, int]] = {}
for k, v in old_power_levels.items():
if isinstance(v, collections.abc.Mapping):
@ -901,7 +897,7 @@ def strip_event(event: EventBase) -> JsonDict:
}
def parse_stripped_state_event(raw_stripped_event: Any) -> Optional[StrippedStateEvent]:
def parse_stripped_state_event(raw_stripped_event: Any) -> StrippedStateEvent | None:
"""
Given a raw value from an event's `unsigned` field, attempt to parse it into a
`StrippedStateEvent`.

View file

@ -19,7 +19,7 @@
#
#
import collections.abc
from typing import Union, cast
from typing import cast
import jsonschema
from pydantic import Field, StrictBool, StrictStr
@ -177,7 +177,7 @@ class EventValidator:
errcode=Codes.BAD_JSON,
)
def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
def validate_builder(self, event: EventBase | EventBuilder) -> None:
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
fields an event would have
@ -249,7 +249,7 @@ class EventValidator:
if not isinstance(d[s], str):
raise SynapseError(400, "'%s' not a string type" % (s,))
def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None:
def _ensure_state_event(self, event: EventBase | EventBuilder) -> None:
if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,))

View file

@ -20,7 +20,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Sequence
from typing import TYPE_CHECKING, Awaitable, Callable, Sequence
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@ -67,7 +67,7 @@ class FederationBase:
# We need to define this lazily otherwise we get a cyclic dependency.
# self._policy_handler = hs.get_room_policy_handler()
self._policy_handler: Optional[RoomPolicyHandler] = None
self._policy_handler: RoomPolicyHandler | None = None
def _lazily_get_policy_handler(self) -> RoomPolicyHandler:
"""Lazily get the room policy handler.
@ -88,9 +88,8 @@ class FederationBase:
self,
room_version: RoomVersion,
pdu: EventBase,
record_failure_callback: Optional[
Callable[[EventBase, str], Awaitable[None]]
] = None,
record_failure_callback: Callable[[EventBase, str], Awaitable[None]]
| None = None,
) -> EventBase:
"""Checks that event is correctly signed by the sending server.

View file

@ -37,7 +37,6 @@ from typing import (
Optional,
Sequence,
TypeVar,
Union,
)
import attr
@ -263,7 +262,7 @@ class FederationClient(FederationBase):
user: UserID,
destination: str,
query: dict[str, dict[str, dict[str, int]]],
timeout: Optional[int],
timeout: int | None,
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
@ -334,7 +333,7 @@ class FederationClient(FederationBase):
@tag_args
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> Optional[list[EventBase]]:
) -> list[EventBase] | None:
"""Requests some more historic PDUs for the given room from the
given destination server.
@ -381,8 +380,8 @@ class FederationClient(FederationBase):
destination: str,
event_id: str,
room_version: RoomVersion,
timeout: Optional[int] = None,
) -> Optional[EventBase]:
timeout: int | None = None,
) -> EventBase | None:
"""Requests the PDU with given origin and ID from the remote home
server. Does not have any caching or rate limiting!
@ -441,7 +440,7 @@ class FederationClient(FederationBase):
@trace
@tag_args
async def get_pdu_policy_recommendation(
self, destination: str, pdu: EventBase, timeout: Optional[int] = None
self, destination: str, pdu: EventBase, timeout: int | None = None
) -> str:
"""Requests that the destination server (typically a policy server)
check the event and return its recommendation on how to handle the
@ -497,8 +496,8 @@ class FederationClient(FederationBase):
@trace
@tag_args
async def ask_policy_server_to_sign_event(
self, destination: str, pdu: EventBase, timeout: Optional[int] = None
) -> Optional[JsonDict]:
self, destination: str, pdu: EventBase, timeout: int | None = None
) -> JsonDict | None:
"""Requests that the destination server (typically a policy server)
sign the event as not spam.
@ -538,8 +537,8 @@ class FederationClient(FederationBase):
destinations: Collection[str],
event_id: str,
room_version: RoomVersion,
timeout: Optional[int] = None,
) -> Optional[PulledPduInfo]:
timeout: int | None = None,
) -> PulledPduInfo | None:
"""Requests the PDU with given origin and ID from the remote home
servers.
@ -832,10 +831,9 @@ class FederationClient(FederationBase):
pdu: EventBase,
origin: str,
room_version: RoomVersion,
record_failure_callback: Optional[
Callable[[EventBase, str], Awaitable[None]]
] = None,
) -> Optional[EventBase]:
record_failure_callback: Callable[[EventBase, str], Awaitable[None]]
| None = None,
) -> EventBase | None:
"""Takes a PDU and checks its signatures and hashes.
If the PDU fails its signature check then we check if we have it in the
@ -931,7 +929,7 @@ class FederationClient(FederationBase):
description: str,
destinations: Iterable[str],
callback: Callable[[str], Awaitable[T]],
failover_errcodes: Optional[Container[str]] = None,
failover_errcodes: Container[str] | None = None,
failover_on_unknown_endpoint: bool = False,
) -> T:
"""Try an operation on a series of servers, until it succeeds
@ -1046,7 +1044,7 @@ class FederationClient(FederationBase):
user_id: str,
membership: str,
content: dict,
params: Optional[Mapping[str, Union[str, Iterable[str]]]],
params: Mapping[str, str | Iterable[str]] | None,
) -> tuple[str, EventBase, RoomVersion]:
"""
Creates an m.room.member event, with context, without participating in the room.
@ -1563,11 +1561,11 @@ class FederationClient(FederationBase):
async def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
limit: int | None = None,
since_token: str | None = None,
search_filter: dict | None = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
third_party_instance_id: str | None = None,
) -> JsonDict:
"""Get the list of public rooms from a remote homeserver
@ -1676,7 +1674,7 @@ class FederationClient(FederationBase):
async def get_room_complexity(
self, destination: str, room_id: str
) -> Optional[JsonDict]:
) -> JsonDict | None:
"""
Fetch the complexity of a remote room from another server.
@ -1987,10 +1985,10 @@ class FederationClient(FederationBase):
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Union[
tuple[int, dict[bytes, list[bytes]], bytes],
tuple[int, dict[bytes, list[bytes]]],
]:
) -> (
tuple[int, dict[bytes, list[bytes]], bytes]
| tuple[int, dict[bytes, list[bytes]]]
):
try:
return await self.transport_layer.federation_download_media(
destination,

View file

@ -28,8 +28,6 @@ from typing import (
Callable,
Collection,
Mapping,
Optional,
Union,
)
from prometheus_client import Counter, Gauge, Histogram
@ -176,13 +174,11 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache: ResponseCache[tuple[str, Optional[str]]] = (
ResponseCache(
clock=hs.get_clock(),
name="state_resp",
server_name=self.server_name,
timeout_ms=30000,
)
self._state_resp_cache: ResponseCache[tuple[str, str | None]] = ResponseCache(
clock=hs.get_clock(),
name="state_resp",
server_name=self.server_name,
timeout_ms=30000,
)
self._state_ids_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache(
clock=hs.get_clock(),
@ -666,7 +662,7 @@ class FederationServer(FederationBase):
async def on_pdu_request(
self, origin: str, event_id: str
) -> tuple[int, Union[JsonDict, str]]:
) -> tuple[int, JsonDict | str]:
pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu:
@ -763,7 +759,7 @@ class FederationServer(FederationBase):
prev_state_ids = await context.get_prev_state_ids()
state_event_ids: Collection[str]
servers_in_room: Optional[Collection[str]]
servers_in_room: Collection[str] | None
if caller_supports_partial_state:
summary = await self.store.get_room_summary(room_id)
state_event_ids = _get_event_ids_for_partial_state_join(
@ -1126,7 +1122,7 @@ class FederationServer(FederationBase):
return {"events": serialize_and_filter_pdus(missing_events, time_now)}
async def on_openid_userinfo(self, token: str) -> Optional[str]:
async def on_openid_userinfo(self, token: str) -> str | None:
ts_now_ms = self._clock.time_msec()
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
@ -1205,7 +1201,7 @@ class FederationServer(FederationBase):
async def _get_next_nonspam_staged_event_for_room(
self, room_id: str, room_version: RoomVersion
) -> Optional[tuple[str, EventBase]]:
) -> tuple[str, EventBase] | None:
"""Fetch the first non-spam event from staging queue.
Args:
@ -1246,8 +1242,8 @@ class FederationServer(FederationBase):
room_id: str,
room_version: RoomVersion,
lock: Lock,
latest_origin: Optional[str] = None,
latest_event: Optional[EventBase] = None,
latest_origin: str | None = None,
latest_event: EventBase | None = None,
) -> None:
"""Process events in the staging area for the given room.

View file

@ -27,7 +27,6 @@ These actions are mostly only used by the :py:mod:`.replication` module.
"""
import logging
from typing import Optional
from synapse.federation.units import Transaction
from synapse.storage.databases.main import DataStore
@ -44,7 +43,7 @@ class TransactionActions:
async def have_responded(
self, origin: str, transaction: Transaction
) -> Optional[tuple[int, JsonDict]]:
) -> tuple[int, JsonDict] | None:
"""Have we already responded to a transaction with the same id and
origin?

View file

@ -42,7 +42,6 @@ from typing import (
TYPE_CHECKING,
Hashable,
Iterable,
Optional,
Sized,
)
@ -217,7 +216,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
destination: str,
edu_type: str,
content: JsonDict,
key: Optional[Hashable] = None,
key: Hashable | None = None,
) -> None:
"""As per FederationSender"""
if self.is_mine_server_name(destination):

View file

@ -138,7 +138,6 @@ from typing import (
Hashable,
Iterable,
Literal,
Optional,
)
import attr
@ -266,7 +265,7 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
destination: str,
edu_type: str,
content: JsonDict,
key: Optional[Hashable] = None,
key: Hashable | None = None,
) -> None:
"""Construct an Edu object, and queue it for sending
@ -410,7 +409,7 @@ class FederationSender(AbstractFederationSender):
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
self._presence_router: Optional["PresenceRouter"] = None
self._presence_router: "PresenceRouter" | None = None
self._transaction_manager = TransactionManager(hs)
self._instance_name = hs.get_instance_name()
@ -481,7 +480,7 @@ class FederationSender(AbstractFederationSender):
def _get_per_destination_queue(
self, destination: str
) -> Optional[PerDestinationQueue]:
) -> PerDestinationQueue | None:
"""Get or create a PerDestinationQueue for the given destination
Args:
@ -605,7 +604,7 @@ class FederationSender(AbstractFederationSender):
)
return
destinations: Optional[Collection[str]] = None
destinations: Collection[str] | None = None
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty
# and so no remote servers in the room
@ -1010,7 +1009,7 @@ class FederationSender(AbstractFederationSender):
destination: str,
edu_type: str,
content: JsonDict,
key: Optional[Hashable] = None,
key: Hashable | None = None,
) -> None:
"""Construct an Edu object, and queue it for sending
@ -1038,7 +1037,7 @@ class FederationSender(AbstractFederationSender):
self.send_edu(edu, key)
def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None:
def send_edu(self, edu: Edu, key: Hashable | None) -> None:
"""Queue an EDU for sending
Args:
@ -1134,7 +1133,7 @@ class FederationSender(AbstractFederationSender):
In order to reduce load spikes, adds a delay between each destination.
"""
last_processed: Optional[str] = None
last_processed: str | None = None
while not self._is_shutdown:
destinations_to_wake = (

View file

@ -23,7 +23,7 @@ import datetime
import logging
from collections import OrderedDict
from types import TracebackType
from typing import TYPE_CHECKING, Hashable, Iterable, Optional
from typing import TYPE_CHECKING, Hashable, Iterable
import attr
from prometheus_client import Counter
@ -121,7 +121,7 @@ class PerDestinationQueue:
self._destination = destination
self.transmission_loop_running = False
self._transmission_loop_enabled = True
self.active_transmission_loop: Optional[defer.Deferred] = None
self.active_transmission_loop: defer.Deferred | None = None
# Flag to signal to any running transmission loop that there is new data
# queued up to be sent.
@ -142,7 +142,7 @@ class PerDestinationQueue:
# Cache of the last successfully-transmitted stream ordering for this
# destination (we are the only updater so this is safe)
self._last_successful_stream_ordering: Optional[int] = None
self._last_successful_stream_ordering: int | None = None
# a queue of pending PDUs
self._pending_pdus: list[EventBase] = []
@ -742,9 +742,9 @@ class _TransactionQueueManager:
queue: PerDestinationQueue
_device_stream_id: Optional[int] = None
_device_list_id: Optional[int] = None
_last_stream_ordering: Optional[int] = None
_device_stream_id: int | None = None
_device_list_id: int | None = None
_last_stream_ordering: int | None = None
_pdus: list[EventBase] = attr.Factory(list)
async def __aenter__(self) -> tuple[list[EventBase], list[Edu]]:
@ -845,9 +845,9 @@ class _TransactionQueueManager:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
if exc_type is not None:
# Failed to send transaction, so we bail out.

View file

@ -31,8 +31,6 @@ from typing import (
Generator,
Iterable,
Mapping,
Optional,
Union,
)
import attr
@ -122,7 +120,7 @@ class TransportLayerClient:
)
async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None
self, destination: str, event_id: str, timeout: int | None = None
) -> JsonDict:
"""Requests the pdu with give id and origin from the given server.
@ -144,7 +142,7 @@ class TransportLayerClient:
)
async def get_policy_recommendation_for_pdu(
self, destination: str, event: EventBase, timeout: Optional[int] = None
self, destination: str, event: EventBase, timeout: int | None = None
) -> JsonDict:
"""Requests the policy recommendation for the given pdu from the given policy server.
@ -171,7 +169,7 @@ class TransportLayerClient:
)
async def ask_policy_server_to_sign_event(
self, destination: str, event: EventBase, timeout: Optional[int] = None
self, destination: str, event: EventBase, timeout: int | None = None
) -> JsonDict:
"""Requests that the destination server (typically a policy server)
sign the event as not spam.
@ -198,7 +196,7 @@ class TransportLayerClient:
async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[Union[JsonDict, list]]:
) -> JsonDict | list | None:
"""Requests `limit` previous PDUs in a given context before list of
PDUs.
@ -235,7 +233,7 @@ class TransportLayerClient:
async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> Union[JsonDict, list]:
) -> JsonDict | list:
"""
Calls a remote federating server at `destination` asking for their
closest event to the given timestamp in the given direction.
@ -270,7 +268,7 @@ class TransportLayerClient:
async def send_transaction(
self,
transaction: Transaction,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
json_data_callback: Callable[[], JsonDict] | None = None,
) -> JsonDict:
"""Sends the given Transaction to its destination
@ -343,7 +341,7 @@ class TransportLayerClient:
room_id: str,
user_id: str,
membership: str,
params: Optional[Mapping[str, Union[str, Iterable[str]]]],
params: Mapping[str, str | Iterable[str]] | None,
) -> JsonDict:
"""Asks a remote server to build and sign us a membership event
@ -528,11 +526,11 @@ class TransportLayerClient:
async def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
limit: int | None = None,
since_token: str | None = None,
search_filter: dict | None = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
third_party_instance_id: str | None = None,
) -> JsonDict:
"""Get the list of public rooms from a remote homeserver
@ -567,7 +565,7 @@ class TransportLayerClient:
)
raise
else:
args: dict[str, Union[str, Iterable[str]]] = {
args: dict[str, str | Iterable[str]] = {
"include_all_networks": "true" if include_all_networks else "false"
}
if third_party_instance_id:
@ -694,7 +692,7 @@ class TransportLayerClient:
user: UserID,
destination: str,
query_content: JsonDict,
timeout: Optional[int],
timeout: int | None,
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
@ -740,7 +738,7 @@ class TransportLayerClient:
user: UserID,
destination: str,
query_content: JsonDict,
timeout: Optional[int],
timeout: int | None,
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
@ -997,13 +995,13 @@ class SendJoinResponse:
event_dict: JsonDict
# The parsed join event from the /send_join response. This will be None if
# "event" is not included in the response.
event: Optional[EventBase] = None
event: EventBase | None = None
# The room state is incomplete
members_omitted: bool = False
# List of servers in the room
servers_in_room: Optional[list[str]] = None
servers_in_room: list[str] | None = None
@attr.s(slots=True, auto_attribs=True)

View file

@ -20,7 +20,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Iterable, Literal, Optional
from typing import TYPE_CHECKING, Iterable, Literal
from synapse.api.errors import FederationDeniedError, SynapseError
from synapse.federation.transport.server._base import (
@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
def __init__(self, hs: "HomeServer", servlet_groups: Optional[list[str]] = None):
def __init__(self, hs: "HomeServer", servlet_groups: list[str] | None = None):
"""Initialize the TransportLayerServer
Will by default register all servlets. For custom behaviour, pass in
@ -135,7 +135,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access:
raise FederationDeniedError(origin)
limit: Optional[int] = parse_integer_from_args(query, "limit", 0)
limit: int | None = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
query, "include_all_networks", default=False
@ -170,7 +170,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access:
raise FederationDeniedError(origin)
limit: Optional[int] = int(content.get("limit", 100))
limit: int | None = int(content.get("limit", 100))
since_token = content.get("since", None)
search_filter = content.get("filter", None)
@ -240,7 +240,7 @@ class OpenIdUserInfo(BaseFederationServlet):
async def on_GET(
self,
origin: Optional[str],
origin: str | None,
content: Literal[None],
query: dict[bytes, list[bytes]],
) -> tuple[int, JsonDict]:
@ -281,7 +281,7 @@ def register_servlets(
resource: HttpServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
servlet_groups: Optional[Iterable[str]] = None,
servlet_groups: Iterable[str] | None = None,
) -> None:
"""Initialize and register servlet classes.

View file

@ -24,7 +24,7 @@ import logging
import re
import time
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, cast
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
@ -77,7 +77,7 @@ class Authenticator:
# A method just so we can pass 'self' as the authenticator to the Servlets
async def authenticate_request(
self, request: SynapseRequest, content: Optional[JsonDict]
self, request: SynapseRequest, content: JsonDict | None
) -> str:
now = self._clock.time_msec()
json_request: JsonDict = {
@ -165,7 +165,7 @@ class Authenticator:
logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes: bytes) -> tuple[str, str, str, Optional[str]]:
def _parse_auth_header(header_bytes: bytes) -> tuple[str, str, str, str | None]:
"""Parse an X-Matrix auth header
Args:
@ -252,7 +252,7 @@ class BaseFederationServlet:
components as specified in the path match regexp.
Returns:
Optional[tuple[int, object]]: either (response code, response object) to
tuple[int, object] | None: either (response code, response object) to
return a JSON response, or None if the request has already been handled.
Raises:
@ -289,7 +289,7 @@ class BaseFederationServlet:
@functools.wraps(func)
async def new_func(
request: SynapseRequest, *args: Any, **kwargs: str
) -> Optional[tuple[int, Any]]:
) -> tuple[int, Any] | None:
"""A callback which can be passed to HttpServer.RegisterPaths
Args:
@ -309,7 +309,7 @@ class BaseFederationServlet:
try:
with start_active_span("authenticate_request"):
origin: Optional[str] = await authenticator.authenticate_request(
origin: str | None = await authenticator.authenticate_request(
request, content
)
except NoAuthenticationError:

View file

@ -24,9 +24,7 @@ from typing import (
TYPE_CHECKING,
Literal,
Mapping,
Optional,
Sequence,
Union,
)
from synapse.api.constants import Direction, EduTypes
@ -156,7 +154,7 @@ class FederationEventServlet(BaseFederationServerServlet):
content: Literal[None],
query: dict[bytes, list[bytes]],
event_id: str,
) -> tuple[int, Union[JsonDict, str]]:
) -> tuple[int, JsonDict | str]:
return await self.handler.on_pdu_request(origin, event_id)
@ -642,7 +640,7 @@ class On3pidBindServlet(BaseFederationServerServlet):
REQUIRE_AUTH = False
async def on_POST(
self, origin: Optional[str], content: JsonDict, query: dict[bytes, list[bytes]]
self, origin: str | None, content: JsonDict, query: dict[bytes, list[bytes]]
) -> tuple[int, JsonDict]:
if "invites" in content:
last_exception = None
@ -676,7 +674,7 @@ class FederationVersionServlet(BaseFederationServlet):
async def on_GET(
self,
origin: Optional[str],
origin: str | None,
content: Literal[None],
query: dict[bytes, list[bytes]],
) -> tuple[int, JsonDict]:
@ -812,7 +810,7 @@ class FederationMediaDownloadServlet(BaseFederationServerServlet):
async def on_GET(
self,
origin: Optional[str],
origin: str | None,
content: Literal[None],
request: SynapseRequest,
media_id: str,
@ -852,7 +850,7 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet):
async def on_GET(
self,
origin: Optional[str],
origin: str | None,
content: Literal[None],
request: SynapseRequest,
media_id: str,

View file

@ -24,7 +24,7 @@ server protocol.
"""
import logging
from typing import Optional, Sequence
from typing import Sequence
import attr
@ -70,7 +70,7 @@ class Edu:
getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
def _none_to_list(edus: Optional[list[JsonDict]]) -> list[JsonDict]:
def _none_to_list(edus: list[JsonDict] | None) -> list[JsonDict]:
if edus is None:
return []
return edus
@ -128,6 +128,6 @@ def filter_pdus_for_valid_depth(pdus: Sequence[JsonDict]) -> list[JsonDict]:
def serialize_and_filter_pdus(
pdus: Sequence[EventBase], time_now: Optional[int] = None
pdus: Sequence[EventBase], time_now: int | None = None
) -> list[JsonDict]:
return filter_pdus_for_valid_depth([pdu.get_pdu_json(time_now) for pdu in pdus])

View file

@ -21,7 +21,7 @@
#
import logging
import random
from typing import TYPE_CHECKING, Awaitable, Callable, Optional
from typing import TYPE_CHECKING, Awaitable, Callable
from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import (
@ -40,9 +40,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
[str, Optional[str], str, JsonDict], Awaitable
]
ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[[str, str | None, str, JsonDict], Awaitable]
class AccountDataHandler:
@ -72,7 +70,7 @@ class AccountDataHandler:
] = []
def register_module_callbacks(
self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
self, on_account_data_updated: ON_ACCOUNT_DATA_UPDATED_CALLBACK | None = None
) -> None:
"""Register callbacks from modules."""
if on_account_data_updated is not None:
@ -81,7 +79,7 @@ class AccountDataHandler:
async def _notify_modules(
self,
user_id: str,
room_id: Optional[str],
room_id: str | None,
account_data_type: str,
content: JsonDict,
) -> None:
@ -143,7 +141,7 @@ class AccountDataHandler:
async def remove_account_data_for_room(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[int]:
) -> int | None:
"""
Deletes the room account data for the given user and account data type.
@ -219,7 +217,7 @@ class AccountDataHandler:
async def remove_account_data_for_user(
self, user_id: str, account_data_type: str
) -> Optional[int]:
) -> int | None:
"""Removes a piece of global account_data for a user.
Args:
@ -324,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
limit: int,
room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
explicit_room_id: str | None = None,
) -> tuple[list[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key

View file

@ -21,7 +21,7 @@
import email.mime.multipart
import email.utils
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -108,8 +108,8 @@ class AccountValidityHandler:
async def on_user_login(
self,
user_id: str,
auth_provider_type: Optional[str],
auth_provider_id: Optional[str],
auth_provider_type: str | None,
auth_provider_id: str | None,
) -> None:
"""Tell third-party modules about a user logins.
@ -326,9 +326,9 @@ class AccountValidityHandler:
async def renew_account_for_user(
self,
user_id: str,
expiration_ts: Optional[int] = None,
expiration_ts: int | None = None,
email_sent: bool = False,
renewal_token: Optional[str] = None,
renewal_token: str | None = None,
) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's

View file

@ -25,7 +25,6 @@ from typing import (
TYPE_CHECKING,
Any,
Mapping,
Optional,
Sequence,
)
@ -71,7 +70,7 @@ class AdminHandler:
self.hs = hs
async def get_redact_task(self, redact_id: str) -> Optional[ScheduledTask]:
async def get_redact_task(self, redact_id: str) -> ScheduledTask | None:
"""Get the current status of an active redaction process
Args:
@ -99,11 +98,9 @@ class AdminHandler:
return ret
async def get_user(self, user: UserID) -> Optional[JsonMapping]:
async def get_user(self, user: UserID) -> JsonMapping | None:
"""Function to get user details"""
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
)
user_info: UserInfo | None = await self._store.get_user_by_id(user.to_string())
if user_info is None:
return None
@ -355,8 +352,8 @@ class AdminHandler:
rooms: list,
requester: JsonMapping,
use_admin: bool,
reason: Optional[str],
limit: Optional[int],
reason: str | None,
limit: int | None,
) -> str:
"""
Start a task redacting the events of the given user in the given rooms
@ -408,7 +405,7 @@ class AdminHandler:
async def _redact_all_events(
self, task: ScheduledTask
) -> tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]:
) -> tuple[TaskStatus, Mapping[str, Any] | None, str | None]:
"""
Task to redact all of a users events in the given rooms, tracking which, if any, events
whose redaction failed

View file

@ -24,8 +24,6 @@ from typing import (
Collection,
Iterable,
Mapping,
Optional,
Union,
)
from prometheus_client import Counter
@ -240,8 +238,8 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
new_token: int | RoomStreamToken | MultiWriterStreamToken,
users: Collection[str | UserID],
) -> None:
"""
This is called by the notifier in the background when an ephemeral event is handled
@ -340,8 +338,8 @@ class ApplicationServicesHandler:
self,
services: list[ApplicationService],
stream_key: StreamKeyType,
new_token: Union[int, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
new_token: int | MultiWriterStreamToken,
users: Collection[str | UserID],
) -> None:
logger.debug("Checking interested services for %s", stream_key)
with Measure(
@ -498,8 +496,8 @@ class ApplicationServicesHandler:
async def _handle_presence(
self,
service: ApplicationService,
users: Collection[Union[str, UserID]],
new_token: Optional[int],
users: Collection[str | UserID],
new_token: int | None,
) -> list[JsonMapping]:
"""
Return the latest presence updates that the given application service should receive.
@ -559,7 +557,7 @@ class ApplicationServicesHandler:
self,
service: ApplicationService,
new_token: int,
users: Collection[Union[str, UserID]],
users: Collection[str | UserID],
) -> list[JsonDict]:
"""
Given an application service, determine which events it should receive
@ -733,7 +731,7 @@ class ApplicationServicesHandler:
async def query_room_alias_exists(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
) -> RoomAliasMapping | None:
"""Check if an application service knows this room alias exists.
Args:
@ -782,7 +780,7 @@ class ApplicationServicesHandler:
return ret
async def get_3pe_protocols(
self, only_protocol: Optional[str] = None
self, only_protocol: str | None = None
) -> dict[str, JsonDict]:
services = self.store.get_app_services()
protocols: dict[str, list[JsonDict]] = {}
@ -935,7 +933,7 @@ class ApplicationServicesHandler:
return claimed_keys, missing
async def query_keys(
self, query: Mapping[str, Optional[list[str]]]
self, query: Mapping[str, list[str] | None]
) -> dict[str, dict[str, dict[str, JsonDict]]]:
"""Query application services for device keys.

View file

@ -33,8 +33,6 @@ from typing import (
Callable,
Iterable,
Mapping,
Optional,
Union,
cast,
)
@ -289,7 +287,7 @@ class AuthHandler:
request_body: dict[str, Any],
description: str,
can_skip_ui_auth: bool = False,
) -> tuple[dict, Optional[str]]:
) -> tuple[dict, str | None]:
"""
Checks that the user is who they claim to be, via a UI auth.
@ -440,7 +438,7 @@ class AuthHandler:
request: SynapseRequest,
clientdict: dict[str, Any],
description: str,
get_new_session_data: Optional[Callable[[], JsonDict]] = None,
get_new_session_data: Callable[[], JsonDict] | None = None,
) -> tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
@ -487,7 +485,7 @@ class AuthHandler:
all the stages in any of the permitted flows.
"""
sid: Optional[str] = None
sid: str | None = None
authdict = clientdict.pop("auth", {})
if "session" in authdict:
sid = authdict["session"]
@ -637,7 +635,7 @@ class AuthHandler:
authdict["session"], stagetype, result
)
def get_session_id(self, clientdict: dict[str, Any]) -> Optional[str]:
def get_session_id(self, clientdict: dict[str, Any]) -> str | None:
"""
Gets the session ID for a client given the client dictionary
@ -673,7 +671,7 @@ class AuthHandler:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
self, session_id: str, key: str, default: Any | None = None
) -> Any:
"""
Retrieve data stored with set_session_data
@ -699,7 +697,7 @@ class AuthHandler:
async def _check_auth_dict(
self, authdict: dict[str, Any], clientip: str
) -> Union[dict[str, Any], str]:
) -> dict[str, Any] | str:
"""Attempt to validate the auth dict provided by a client
Args:
@ -774,9 +772,9 @@ class AuthHandler:
async def refresh_token(
self,
refresh_token: str,
access_token_valid_until_ms: Optional[int],
refresh_token_valid_until_ms: Optional[int],
) -> tuple[str, str, Optional[int]]:
access_token_valid_until_ms: int | None,
refresh_token_valid_until_ms: int | None,
) -> tuple[str, str, int | None]:
"""
Consumes a refresh token and generate both a new access token and a new refresh token from it.
@ -909,8 +907,8 @@ class AuthHandler:
self,
user_id: str,
duration_ms: int = (2 * 60 * 1000),
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
auth_provider_id: str | None = None,
auth_provider_session_id: str | None = None,
) -> str:
login_token = self.generate_login_token()
now = self._clock.time_msec()
@ -928,8 +926,8 @@ class AuthHandler:
self,
user_id: str,
device_id: str,
expiry_ts: Optional[int],
ultimate_session_expiry_ts: Optional[int],
expiry_ts: int | None,
ultimate_session_expiry_ts: int | None,
) -> tuple[str, int]:
"""
Creates a new refresh token for the user with the given user ID.
@ -961,11 +959,11 @@ class AuthHandler:
async def create_access_token_for_user_id(
self,
user_id: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
device_id: str | None,
valid_until_ms: int | None,
puppets_user_id: str | None = None,
is_appservice_ghost: bool = False,
refresh_token_id: Optional[int] = None,
refresh_token_id: int | None = None,
) -> str:
"""
Creates a new access token for the user with the given user ID.
@ -1034,7 +1032,7 @@ class AuthHandler:
return access_token
async def check_user_exists(self, user_id: str) -> Optional[str]:
async def check_user_exists(self, user_id: str) -> str | None:
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
@ -1061,9 +1059,7 @@ class AuthHandler:
"""
return await self.store.is_user_approved(user_id)
async def _find_user_id_and_pwd_hash(
self, user_id: str
) -> Optional[tuple[str, str]]:
async def _find_user_id_and_pwd_hash(self, user_id: str) -> tuple[str, str] | None:
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
@ -1141,7 +1137,7 @@ class AuthHandler:
login_submission: dict[str, Any],
ratelimit: bool = False,
is_reauth: bool = False,
) -> tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@ -1297,7 +1293,7 @@ class AuthHandler:
self,
username: str,
login_submission: dict[str, Any],
) -> tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@ -1386,7 +1382,7 @@ class AuthHandler:
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
) -> tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
) -> tuple[str | None, Callable[["LoginResponse"], Awaitable[None]] | None]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@ -1413,7 +1409,7 @@ class AuthHandler:
# if result is None then return (None, None)
return None, None
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
async def _check_local_password(self, user_id: str, password: str) -> str | None:
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
@ -1528,8 +1524,8 @@ class AuthHandler:
async def delete_access_tokens_for_user(
self,
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
except_token_id: int | None = None,
device_id: str | None = None,
) -> None:
"""Invalidate access tokens belonging to a user
@ -1700,9 +1696,7 @@ class AuthHandler:
return await defer_to_thread(self.hs.get_reactor(), _do_hash)
async def validate_hash(
self, password: str, stored_hash: Union[bytes, str]
) -> bool:
async def validate_hash(self, password: str, stored_hash: bytes | str) -> bool:
"""Validates that self.hash(password) == stored_hash.
Args:
@ -1799,9 +1793,9 @@ class AuthHandler:
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
extra_attributes: JsonDict | None = None,
new_user: bool = False,
auth_provider_session_id: Optional[str] = None,
auth_provider_session_id: str | None = None,
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
@ -1960,7 +1954,7 @@ def load_single_legacy_password_auth_provider(
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
def async_wrapper(f: Callable | None) -> Callable[..., Awaitable] | None:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
@ -1973,7 +1967,7 @@ def load_single_legacy_password_auth_provider(
async def wrapped_check_password(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[tuple[str, Optional[Callable]]]:
) -> tuple[str, Callable | None] | None:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
@ -1992,12 +1986,12 @@ def load_single_legacy_password_auth_provider(
return wrapped_check_password
# We need to wrap check_auth as in the old form it could return
# just a str, but now it must return Optional[tuple[str, Optional[Callable]]
# just a str, but now it must return tuple[str, Callable | None] | None
if f.__name__ == "check_auth":
async def wrapped_check_auth(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[tuple[str, Optional[Callable]]]:
) -> tuple[str, Callable | None] | None:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
@ -2013,12 +2007,12 @@ def load_single_legacy_password_auth_provider(
return wrapped_check_auth
# We need to wrap check_3pid_auth as in the old form it could return
# just a str, but now it must return Optional[tuple[str, Optional[Callable]]
# just a str, but now it must return tuple[str, Callable | None] | None
if f.__name__ == "check_3pid_auth":
async def wrapped_check_3pid_auth(
medium: str, address: str, password: str
) -> Optional[tuple[str, Optional[Callable]]]:
) -> tuple[str, Callable | None] | None:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
@ -2044,10 +2038,10 @@ def load_single_legacy_password_auth_provider(
# If the module has these methods implemented, then we pull them out
# and register them as hooks.
check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper(
check_3pid_auth_hook: CHECK_3PID_AUTH_CALLBACK | None = async_wrapper(
getattr(provider, "check_3pid_auth", None)
)
on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper(
on_logged_out_hook: ON_LOGGED_OUT_CALLBACK | None = async_wrapper(
getattr(provider, "on_logged_out", None)
)
@ -2085,24 +2079,20 @@ def load_single_legacy_password_auth_provider(
CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str],
Awaitable[
Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
Awaitable[tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None],
]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
ON_LOGGED_OUT_CALLBACK = Callable[[str, str | None, str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict],
Awaitable[
Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
Awaitable[tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None],
]
GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
Awaitable[str | None],
]
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
Awaitable[str | None],
]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
@ -2133,18 +2123,15 @@ class PasswordAuthProvider:
def register_password_auth_provider_callbacks(
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
auth_checkers: Optional[
dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
check_3pid_auth: CHECK_3PID_AUTH_CALLBACK | None = None,
on_logged_out: ON_LOGGED_OUT_CALLBACK | None = None,
is_3pid_allowed: IS_3PID_ALLOWED_CALLBACK | None = None,
auth_checkers: dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK]
| None = None,
get_username_for_registration: GET_USERNAME_FOR_REGISTRATION_CALLBACK
| None = None,
get_displayname_for_registration: GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
| None = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
@ -2214,7 +2201,7 @@ class PasswordAuthProvider:
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None:
"""Check if the user has presented valid login credentials
Args:
@ -2245,14 +2232,14 @@ class PasswordAuthProvider:
continue
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# Check that the callback returned a tuple[str, Callable | None]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]",
" tuple[str, Callable | None] | None",
callback,
result,
)
@ -2265,24 +2252,24 @@ class PasswordAuthProvider:
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]",
" tuple[str, Callable | None] | None",
callback,
result,
)
continue
# the second should be Optional[Callable]
# the second should be Callable | None
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]",
" tuple[str, Callable | None] | None",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
# The result is a (str, callback | None) tuple so return the successful result
return result
# If this point has been reached then none of the callbacks successfully authenticated
@ -2291,7 +2278,7 @@ class PasswordAuthProvider:
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None:
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
@ -2308,14 +2295,14 @@ class PasswordAuthProvider:
continue
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# Check that the callback returned a tuple[str, Callable | None]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]",
" tuple[str, Callable | None] | None",
callback,
result,
)
@ -2328,24 +2315,24 @@ class PasswordAuthProvider:
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]",
" tuple[str, Callable | None] | None",
callback,
result,
)
continue
# the second should be Optional[Callable]
# the second should be Callable | None
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]",
" tuple[str, Callable | None] | None",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
# The result is a (str, callback | None) tuple so return the successful result
return result
# If this point has been reached then none of the callbacks successfully authenticated
@ -2353,7 +2340,7 @@ class PasswordAuthProvider:
return None
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
self, user_id: str, device_id: str | None, access_token: str
) -> None:
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
@ -2367,7 +2354,7 @@ class PasswordAuthProvider:
self,
uia_results: JsonDict,
params: JsonDict,
) -> Optional[str]:
) -> str | None:
"""Defines the username to use when registering the user, using the credentials
and parameters provided during the UIA flow.
@ -2412,7 +2399,7 @@ class PasswordAuthProvider:
self,
uia_results: JsonDict,
params: JsonDict,
) -> Optional[str]:
) -> str | None:
"""Defines the display name to use when registering the user, using the
credentials and parameters provided during the UIA flow.

Some files were not shown because too many files have changed in this diff Show more