diff --git a/changelog.d/19204.feature b/changelog.d/19204.feature new file mode 100644 index 0000000000..04a2d93e32 --- /dev/null +++ b/changelog.d/19204.feature @@ -0,0 +1 @@ +Add a new config option [`enable_local_media_storage`](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#enable_local_media_storage) which controls whether media is additionally stored locally when using configured `media_storage_providers`. Setting this to `false` allows off-site media storage without a local cache. Contributed by Patrice Brend'amour @dr.allgood. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index badfe0a03f..98dbc63d78 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2111,6 +2111,16 @@ Example configuration: enable_media_repo: false ``` --- +### `enable_local_media_storage` + +*(boolean)* Enable the local on-disk media storage provider. When disabled, media is stored only in configured `media_storage_providers` and temporary files are used for processing. +**Warning:** If this option is set to `false` and no `media_storage_providers` are configured, all media requests will return 404 errors as there will be no storage backend available. Defaults to `true`. + +Example configuration: +```yaml +enable_local_media_storage: false +``` +--- ### `media_store_path` *(string)* Directory where uploaded images and attachments are stored. Defaults to `"media_store"`. diff --git a/schema/synapse-config.schema.yaml b/schema/synapse-config.schema.yaml index ca8db9c9ee..d20520bec3 100644 --- a/schema/synapse-config.schema.yaml +++ b/schema/synapse-config.schema.yaml @@ -2348,6 +2348,19 @@ properties: default: true examples: - false + enable_local_media_storage: + type: boolean + description: >- + Enable the local on-disk media storage provider. 
When disabled, media is + stored only in configured `media_storage_providers` and temporary files are + used for processing. + + **Warning:** If this option is set to `false` and no `media_storage_providers` + are configured, all media requests will return 404 errors as there will be + no storage backend available. + default: true + examples: + - false media_store_path: type: string description: Directory where uploaded images and attachments are stored. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 221130b0cd..c87442aace 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -174,6 +174,11 @@ class ContentRepositoryConfig(Config): config.get("media_store_path", "media_store") ) + # Whether to enable the local media storage provider. When disabled, + # media will only be stored in configured storage providers and temp + # files will be used for processing. + self.enable_local_media_storage = config.get("enable_local_media_storage", True) + backup_media_store_path = config.get("backup_media_store_path") synchronous_backup_media_store = config.get( diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 8d38c1655f..3344e4c7be 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -64,7 +64,10 @@ from synapse.media.media_storage import ( SHA256TransparentIOReader, SHA256TransparentIOWriter, ) -from synapse.media.storage_provider import StorageProviderWrapper +from synapse.media.storage_provider import ( + FileStorageProviderBackend, + StorageProviderWrapper, +) from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia @@ -142,10 +145,23 @@ class MediaRepository: ) storage_providers.append(provider) + # If local media storage is enabled, create the local provider + local_provider: 
FileStorageProviderBackend | None = None + if hs.config.media.enable_local_media_storage and self.primary_base_path: + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) + self.media_storage: MediaStorage = MediaStorage( - self.hs, self.primary_base_path, self.filepaths, storage_providers + self.hs, self.filepaths, storage_providers, local_provider ) + # Log a warning if there are no storage backends configured + if not hs.config.media.enable_local_media_storage and not storage_providers: + logger.warning( + "Local media storage is disabled and no media_storage_providers are " + "configured. All media requests will return 404 errors as there is " + "no storage backend available." + ) + self.clock.looping_call( self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS ) @@ -782,10 +798,18 @@ class MediaRepository: except SynapseError: raise except Exception as e: - # An exception may be because we downloaded media in another - # process, so let's check if we magically have the media. - media_info = await self.store.get_cached_remote_media(server_name, media_id) - if not media_info: + # If this is a constraint violation, it means another worker + # downloaded the media first. We should fetch the existing media info. 
+ if isinstance(e, self.store.database_engine.module.IntegrityError): + # The file has already been cleaned up in _download_remote_file + # Just fetch the existing media info + media_info = await self.store.get_cached_remote_media( + server_name, media_id + ) + if not media_info: + # This shouldn't happen, but let's raise an error if it does + raise SynapseError(500, "Failed to fetch remote media") + else: raise e file_id = media_info.filesystem_id @@ -806,6 +830,39 @@ class MediaRepository: responder = await self.media_storage.fetch_media(file_info) return responder, media_info + async def _store_remote_media_with_cleanup( + self, + server_name: str, + media_id: str, + media_type: str, + time_now_ms: int, + upload_name: str | None, + media_length: int, + filesystem_id: str, + sha256: str, + fname: str, + ) -> None: + """Store remote media in database and clean up file on constraint violation.""" + try: + await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=time_now_ms, + upload_name=upload_name, + media_length=media_length, + filesystem_id=filesystem_id, + sha256=sha256, + ) + except self.store.database_engine.module.IntegrityError: + # Another worker downloaded the media first. Clean up our file. + try: + os.remove(fname) + except Exception: + pass + # Re-raise so the caller can handle it + raise + async def _download_remote_file( self, server_name: str, @@ -890,26 +947,21 @@ class MediaRepository: upload_name = get_filename_from_headers(headers) time_now_ms = self.clock.time_msec() - # Multiple remote media download requests can race (when using - # multiple media repos), so this may throw a violation constraint - # exception. If it does we'll delete the newly downloaded file from - # disk (as we're in the ctx manager). - # - # However: we've already called `finish()` so we may have also - # written to the storage providers. 
This is preferable to the - # alternative where we call `finish()` *after* this, where we could - # end up having an entry in the DB but fail to write the files to - # the storage providers. - await self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=time_now_ms, - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - sha256=sha256writer.hexdigest(), - ) + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk. + await self._store_remote_media_with_cleanup( + server_name=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=time_now_ms, + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + sha256=sha256writer.hexdigest(), + fname=fname, + ) logger.info("Stored remote media in file %r", fname) @@ -1023,26 +1075,21 @@ class MediaRepository: upload_name = get_filename_from_headers(headers) time_now_ms = self.clock.time_msec() - # Multiple remote media download requests can race (when using - # multiple media repos), so this may throw a violation constraint - # exception. If it does we'll delete the newly downloaded file from - # disk (as we're in the ctx manager). - # - # However: we've already called `finish()` so we may have also - # written to the storage providers. This is preferable to the - # alternative where we call `finish()` *after* this, where we could - # end up having an entry in the DB but fail to write the files to - # the storage providers. 
- await self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=time_now_ms, - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - sha256=sha256writer.hexdigest(), - ) + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk. + await self._store_remote_media_with_cleanup( + server_name=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=time_now_ms, + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + sha256=sha256writer.hexdigest(), + fname=fname, + ) logger.debug("Stored remote media in file %r", fname) @@ -1115,32 +1162,31 @@ class MediaRepository: t_type: str, url_cache: bool, ) -> tuple[str, FileInfo] | None: - input_path = await self.media_storage.ensure_media_is_in_local_cache( + async with self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) - ) + ) as input_path: + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", + media_id, + t_method, + t_type, + e, + ) + return None - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", - media_id, - t_method, - t_type, - e, - ) - return None - - with thumbnailer: - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ 
-1191,33 +1237,32 @@ class MediaRepository: t_method: str, t_type: str, ) -> str | None: - input_path = await self.media_storage.ensure_media_is_in_local_cache( + async with self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id) - ) + ) as input_path: + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + t_method, + t_type, + e, + ) + return None - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", - media_id, - server_name, - t_method, - t_type, - e, - ) - return None - - with thumbnailer: - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -1287,151 +1332,157 @@ class MediaRepository: if not requirements: return None - input_path = await self.media_storage.ensure_media_is_in_local_cache( + async with self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate thumbnails for remote media %s from %s of type %s: %s", - media_id, - server_name, - media_type, - e, - ) - return None - - with thumbnailer: - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, + ) as input_path: + try: + 
thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate thumbnails for remote media %s from %s of type %s: %s", + media_id, + server_name, + media_type, + e, ) return None - if thumbnailer.transpose_method is not None: - m_width, m_height = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.transpose - ) + with thumbnailer: + m_width = thumbnailer.width + m_height = thumbnailer.height - # We deduplicate the thumbnail sizes by ignoring the cropped versions if - # they have the same dimensions of a scaled one. - thumbnails: dict[tuple[int, int, str], str] = {} - for requirement in requirements: - if requirement.method == "crop": - thumbnails.setdefault( - (requirement.width, requirement.height, requirement.media_type), - requirement.method, + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, ) - elif requirement.method == "scale": - t_width, t_height = thumbnailer.aspect( - requirement.width, requirement.height - ) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - thumbnails[(t_width, t_height, requirement.media_type)] = ( - requirement.method + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = await defer_to_thread( + self.hs.get_reactor(), thumbnailer.transpose ) - # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in thumbnails.items(): - # Generate the thumbnail - if t_method == "crop": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - thumbnailer.crop, - t_width, - t_height, - t_type, + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. 
+ thumbnails: dict[tuple[int, int, str], str] = {} + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + ( + requirement.width, + requirement.height, + requirement.media_type, + ), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[(t_width, t_height, requirement.media_type)] = ( + requirement.method + ) + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.items(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.crop, + t_width, + t_height, + t_type, + ) + elif t_method == "scale": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.scale, + t_width, + t_height, + t_type, + ) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + length=t_byte_source.tell(), + ), ) - elif t_method == "scale": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - thumbnailer.scale, - t_width, - t_height, - t_type, - ) - else: - logger.error("Unrecognized method: %r", t_method) - continue - if not t_byte_source: - continue - - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - length=t_byte_source.tell(), - ), - ) - - async with self.media_storage.store_into_file(file_info) as (f, fname): - try: - await self.media_storage.write_to_file(t_byte_source, f) - finally: - t_byte_source.close() - 
- # We flush and close the file to ensure that the bytes have - # been written before getting the size. - f.flush() - f.close() - - t_len = os.path.getsize(fname) - - # Write to database - if server_name: - # Multiple remote media download requests can race (when - # using multiple media repos), so this may throw a violation - # constraint exception. If it does we'll delete the newly - # generated thumbnail from disk (as we're in the ctx - # manager). - # - # However: we've already called `finish()` so we may have - # also written to the storage providers. This is preferable - # to the alternative where we call `finish()` *after* this, - # where we could end up having an entry in the DB but fail - # to write the files to the storage providers. + async with self.media_storage.store_into_file(file_info) as ( + f, + fname, + ): try: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - except Exception as e: - thumbnail_exists = ( - await self.store.get_remote_media_thumbnail( + await self.media_storage.write_to_file(t_byte_source, f) + finally: + t_byte_source.close() + + # We flush and close the file to ensure that the bytes have + # been written before getting the size. + f.flush() + f.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. 
+ try: + await self.store.store_remote_media_thumbnail( server_name, media_id, + file_id, t_width, t_height, t_type, + t_method, + t_len, ) + except Exception as e: + thumbnail_exists = ( + await self.store.get_remote_media_thumbnail( + server_name, + media_id, + t_width, + t_height, + t_type, + ) + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len ) - if not thumbnail_exists: - raise e - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) return {"width": m_width, "height": m_height} diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index e83869bf4d..4c942f07f8 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -24,6 +24,7 @@ import json import logging import os import shutil +import tempfile from contextlib import closing from io import BytesIO from types import TracebackType @@ -49,13 +50,13 @@ from twisted.internet.interfaces import IConsumer from synapse.api.errors import NotFoundError from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.opentracing import start_active_span, trace, trace_with_opname -from synapse.media._base import ThreadedFileSender +from synapse.media.storage_provider import FileStorageProviderBackend from synapse.util.clock import Clock from synapse.util.duration import Duration from synapse.util.file_consumer import BackgroundFileConsumer from ..types import JsonDict -from ._base import FileInfo, Responder +from ._base import FileInfo, Responder, ThreadedFileSender from .filepath import MediaFilePaths if TYPE_CHECKING: @@ -150,27 +151,30 @@ class SHA256TransparentIOReader: class MediaStorage: - """Responsible for storing/fetching files from local sources. + """Responsible for storing/fetching files from storage providers. 
Args: hs - local_media_directory: Base path where we store media on disk filepaths storage_providers: List of StorageProvider that are used to fetch and store files. + local_provider: Optional local file storage provider for caching media on disk. """ def __init__( self, hs: "HomeServer", - local_media_directory: str, filepaths: MediaFilePaths, storage_providers: Sequence["StorageProvider"], + local_provider: "FileStorageProviderBackend | None" = None, ): self.hs = hs self.reactor = hs.get_reactor() - self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers + self.local_provider = local_provider + self.local_media_directory: str | None = None + if local_provider is not None: + self.local_media_directory = local_provider.base_directory self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self.clock = hs.get_clock() @@ -205,11 +209,11 @@ class MediaStorage: """Async Context manager used to get a file like object to write into, as described by file_info. - Actually yields a 2-tuple (file, fname,), where file is a file - like object that can be written to and fname is the absolute path of file - on disk. + Actually yields a 2-tuple (file, media_filepath,), where file is a file + like object that can be written to and media_filepath is the absolute path + of the file on disk. - fname can be used to read the contents from after upload, e.g. to + media_filepath can be used to read the contents from after upload, e.g. to generate thumbnails. Args: @@ -217,25 +221,33 @@ class MediaStorage: Example: - async with media_storage.store_into_file(info) as (f, fname,): + async with media_storage.store_into_file(info) as (f, media_filepath,): # .. write into f ... 
""" path = self._file_info_to_path(file_info) - fname = os.path.join(self.local_media_directory, path) + is_temp_file = False - dirname = os.path.dirname(fname) - os.makedirs(dirname, exist_ok=True) + if self.local_provider: + media_filepath = os.path.join(self.local_media_directory, path) # type: ignore[arg-type] + os.makedirs(os.path.dirname(media_filepath), exist_ok=True) - try: with start_active_span("writing to main media repo"): - with open(fname, "wb") as f: - yield f, fname + with open(media_filepath, "wb") as f: + yield f, media_filepath + else: + # No local provider, write to temp file + is_temp_file = True + with tempfile.NamedTemporaryFile(delete=False) as f: + media_filepath = f.name + yield cast(BinaryIO, f), media_filepath - with start_active_span("writing to other storage providers"): + # Spam check and store to other providers (runs for both local and temp file cases) + try: + with start_active_span("spam checking and writing to storage providers"): spam_check = ( await self._spam_checker_module_callbacks.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info + ReadableFileWrapper(self.clock, media_filepath), file_info ) ) if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: @@ -251,17 +263,23 @@ class MediaStorage: with start_active_span(str(provider)): await provider.store_file(path, file_info) + # If using a temp file, delete it after uploading to storage providers + if is_temp_file: + try: + os.remove(media_filepath) + except Exception: + pass + except Exception as e: try: - os.remove(fname) + os.remove(media_filepath) except Exception: pass raise e from None async def fetch_media(self, file_info: FileInfo) -> Responder | None: - """Attempts to fetch media described by file_info from the local cache - and configured storage providers. + """Attempts to fetch media described by file_info from the configured storage providers. 
Args: file_info: Metadata about the media file @@ -269,6 +287,18 @@ class MediaStorage: Returns: Returns a Responder if the file was found, otherwise None. """ + # URL cache files are stored locally and should not go through storage providers + if file_info.url_cache: + path = self._file_info_to_path(file_info) + if self.local_provider: + local_path = os.path.join(self.local_media_directory, path) # type: ignore[arg-type] + if os.path.isfile(local_path): + # Import here to avoid circular import + from .media_storage import FileResponder + + return FileResponder(self.hs, open(local_path, "rb")) + return None + paths = [self._file_info_to_path(file_info)] # fallback for remote thumbnails with no method in the filename @@ -283,16 +313,18 @@ class MediaStorage: ) ) - for path in paths: - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - logger.debug("responding with local file %s", local_path) - return FileResponder(self.hs, open(local_path, "rb")) - logger.debug("local file %s did not exist", local_path) + # Check local provider first, then other storage providers + if self.local_provider: + for path in paths: + res: Any = await self.local_provider.fetch(path, file_info) + if res: + logger.debug("Streaming %s from %s", path, self.local_provider) + return res + logger.debug("%s not found on %s", path, self.local_provider) for provider in self.storage_providers: for path in paths: - res: Any = await provider.fetch(path, file_info) + res = await provider.fetch(path, file_info) if res: logger.debug("Streaming %s from %s", path, provider) return res @@ -301,50 +333,93 @@ class MediaStorage: return None @trace - async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: - """Ensures that the given file is in the local cache. Attempts to - download it from storage providers if it isn't. 
+ @contextlib.asynccontextmanager + async def ensure_media_is_in_local_cache( + self, file_info: FileInfo + ) -> AsyncIterator[str]: + """Async context manager that ensures the given file is in the local cache. + Attempts to download it from storage providers if it isn't. + + When no local provider is configured, the file is downloaded to a temporary + location and automatically cleaned up when the context manager exits. Args: file_info - Returns: + Yields: Full path to local file + + Example: + async with media_storage.ensure_media_is_in_local_cache(file_info) as path: + # use path to read the file """ path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return local_path + if self.local_provider: + local_path = os.path.join(self.local_media_directory, path) # type: ignore[arg-type] + if os.path.exists(local_path): + yield local_path + return - # Fallback for paths without method names - # Should be removed in the future - if file_info.thumbnail and file_info.server_name: - legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - legacy_local_path = os.path.join(self.local_media_directory, legacy_path) - if os.path.exists(legacy_local_path): - return legacy_local_path + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + legacy_local_path = os.path.join( + self.local_media_directory, # type: ignore[arg-type] + legacy_path, + ) + if os.path.exists(legacy_local_path): + 
yield legacy_local_path + return - dirname = os.path.dirname(local_path) - os.makedirs(dirname, exist_ok=True) + os.makedirs(os.path.dirname(local_path), exist_ok=True) - for provider in self.storage_providers: - res: Any = await provider.fetch(path, file_info) - if res: - with res: - consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.reactor - ) - await res.write_to_consumer(consumer) - await consumer.wait() - return local_path + for provider in self.storage_providers: + remote_res: Any = await provider.fetch(path, file_info) + if remote_res: + with remote_res: + consumer = BackgroundFileConsumer( + open(local_path, "wb"), self.reactor + ) + await remote_res.write_to_consumer(consumer) + await consumer.wait() + yield local_path + return - raise NotFoundError() + raise NotFoundError() + else: + # No local provider, download to temp file and clean up after use + for provider in self.storage_providers: + res: Any = await provider.fetch(path, file_info) + if res: + temp_path = None + try: + with tempfile.NamedTemporaryFile( + delete=False, suffix=os.path.splitext(path)[1] + ) as tmp: + temp_path = tmp.name + with res: + consumer = BackgroundFileConsumer( + open(temp_path, "wb"), self.reactor + ) + await res.write_to_consumer(consumer) + await consumer.wait() + yield temp_path + finally: + if temp_path: + try: + os.remove(temp_path) + except Exception: + pass + return + + raise NotFoundError() @trace def _file_info_to_path(self, file_info: FileInfo) -> str: diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py index a87ffa0892..c5faa25b96 100644 --- a/synapse/media/storage_provider.py +++ b/synapse/media/storage_provider.py @@ -31,7 +31,6 @@ from synapse.logging.opentracing import start_active_span, trace_with_opname from synapse.util.async_helpers import maybe_awaitable from ._base import FileInfo, Responder -from .media_storage import FileResponder logger = logging.getLogger(__name__) @@ -178,6 +177,9 @@ class 
FileStorageProviderBackend(StorageProvider): backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): + # Import here to avoid circular import + from .media_storage import FileResponder + return FileResponder(self.hs, open(backup_fname, "rb")) return None diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index fd65131c63..cb86a0ff66 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -633,13 +633,6 @@ class ThumbnailProvider: # width/height/method so we can just call the "generate exact" # methods. - # First let's check that we do actually have the original image - # still. This will throw a 404 if we don't. - # TODO: We should refetch the thumbnails for remote media. - await self.media_storage.ensure_media_is_in_local_cache( - FileInfo(server_name, file_id, url_cache=url_cache) - ) - if server_name: await self.media_repo.generate_remote_exact_thumbnail( server_name, diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index 1e849fa605..9c4e813537 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -49,6 +49,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): hs.config.media.media_store_path = self.primary_base_path + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) storage_providers = [ StorageProviderWrapper( FileStorageProviderBackend(hs, self.secondary_base_path), @@ -60,7 +61,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers + hs, self.filepaths, storage_providers, local_provider ) self.media_repo = hs.get_media_repository() @@ -187,7 +188,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): self.assertNotIn("body", 
channel.result) -class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): +class FederationMediaTest(unittest.FederatingHomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") @@ -197,6 +198,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): hs.config.media.media_store_path = self.primary_base_path + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) storage_providers = [ StorageProviderWrapper( FileStorageProviderBackend(hs, self.secondary_base_path), @@ -208,7 +210,116 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers + hs, self.filepaths, storage_providers, local_provider + ) + self.media_repo = hs.get_media_repository() + + def test_thumbnail_download_scaled(self) -> None: + content = io.BytesIO(small_png.data) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_thumbnail", + content, + 67, + UserID.from_string("@user_id:whatever.org"), + ) + ) + # test with an image file + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=scale", + ) + self.pump() + self.assertEqual(200, channel.code) + + content_type = channel.headers.getRawHeaders("content-type") + assert content_type is not None + assert "multipart/mixed" in content_type[0] + assert "boundary" in content_type[0] + + # extract boundary + boundary = content_type[0].split("boundary=")[1] + # split on boundary and check that json field and expected value exist + body = channel.result.get("body") + assert body is not None + stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) 
+ found_json = any( + b"\r\nContent-Type: application/json\r\n\r\n{}" in field + for field in stripped_bytes + ) + self.assertTrue(found_json) + + # check that the png file exists and matches the expected scaled bytes + found_file = any(small_png.expected_scaled in field for field in stripped_bytes) + self.assertTrue(found_file) + + def test_thumbnail_download_cropped(self) -> None: + content = io.BytesIO(small_png.data) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_thumbnail", + content, + 67, + UserID.from_string("@user_id:whatever.org"), + ) + ) + # test with an image file + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=crop", + ) + self.pump() + self.assertEqual(200, channel.code) + + content_type = channel.headers.getRawHeaders("content-type") + assert content_type is not None + assert "multipart/mixed" in content_type[0] + assert "boundary" in content_type[0] + + # extract boundary + boundary = content_type[0].split("boundary=")[1] + # split on boundary and check that json field and expected value exist + body = channel.result.get("body") + assert body is not None + stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) + found_json = any( + b"\r\nContent-Type: application/json\r\n\r\n{}" in field + for field in stripped_bytes + ) + self.assertTrue(found_json) + + # check that the png file exists and matches the expected cropped bytes + found_file = any( + small_png.expected_cropped in field for field in stripped_bytes + ) + self.assertTrue(found_file) + + +class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") + self.addCleanup(shutil.rmtree, self.test_dir) + self.primary_base_path = 
os.path.join(self.test_dir, "primary") + self.secondary_base_path = os.path.join(self.test_dir, "secondary") + + hs.config.media.media_store_path = self.primary_base_path + + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) + storage_providers = [ + StorageProviderWrapper( + FileStorageProviderBackend(hs, self.secondary_base_path), + store_local=True, + store_remote=False, + store_synchronous=True, + ) + ] + + self.filepaths = MediaFilePaths(self.primary_base_path) + self.media_storage = MediaStorage( + hs, self.filepaths, storage_providers, local_provider ) self.media_repo = hs.get_media_repository() diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index e56354e0b3..631718a366 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -48,7 +48,10 @@ from synapse.logging.context import make_deferred_yieldable from synapse.media._base import FileInfo, ThumbnailInfo from synapse.media.filepath import MediaFilePaths from synapse.media.media_storage import MediaStorage, ReadableFileWrapper -from synapse.media.storage_provider import FileStorageProviderBackend +from synapse.media.storage_provider import ( + FileStorageProviderBackend, + StorageProviderWrapper, +) from synapse.media.thumbnailer import ThumbnailProvider from synapse.module_api import ModuleApi from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers @@ -77,11 +80,19 @@ class MediaStorageTests(unittest.HomeserverTestCase): hs.config.media.media_store_path = self.primary_base_path - storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) + storage_providers = [ + StorageProviderWrapper( + FileStorageProviderBackend(hs, self.secondary_base_path), + store_local=True, + store_remote=False, + store_synchronous=True, + ), + ] self.filepaths = MediaFilePaths(self.primary_base_path) 
self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers + hs, self.filepaths, storage_providers, local_provider ) def test_ensure_media_is_in_local_cache(self) -> None: @@ -102,29 +113,31 @@ class MediaStorageTests(unittest.HomeserverTestCase): # to the local cache. file_info = FileInfo(None, media_id) + async def test_ensure_media() -> None: + async with self.media_storage.ensure_media_is_in_local_cache( + file_info + ) as local_path: + self.assertTrue(os.path.exists(local_path)) + + # Asserts the file is under the expected local cache directory + self.assertEqual( + os.path.commonprefix([self.primary_base_path, local_path]), + self.primary_base_path, + ) + + with open(local_path) as f: + body = f.read() + + self.assertEqual(test_body, body) + # This uses a real blocking threadpool so we have to wait for it to be # actually done :/ - x = defer.ensureDeferred( - self.media_storage.ensure_media_is_in_local_cache(file_info) - ) + x = defer.ensureDeferred(test_ensure_media()) # Hotloop until the threadpool does its job... self.wait_on_thread(x) - local_path = self.get_success(x) - - self.assertTrue(os.path.exists(local_path)) - - # Asserts the file is under the expected local cache directory - self.assertEqual( - os.path.commonprefix([self.primary_base_path, local_path]), - self.primary_base_path, - ) - - with open(local_path) as f: - body = f.read() - - self.assertEqual(test_body, body) + self.get_success(x) @attr.s(auto_attribs=True, slots=True, frozen=True) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 8dbb989850..000dbb5594 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -206,6 +206,53 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): # We expect only one new file to have been persisted. 
 self.assertEqual(start_count + 1, self._count_remote_media())
 
+    @override_config(
+        {"enable_authenticated_media": False, "enable_local_media_storage": False}
+    )
+    def test_download_simple_file_race_no_local_storage(self) -> None:
+        """Test that fetching remote media fails cleanly when local storage is disabled.
+
+        This test verifies that the system handles the case where local storage
+        is disabled. Without storage providers configured, the media cannot be
+        retrieved (a 404 is returned), but the important thing is that the race
+        condition is still handled correctly.
+        """
+        hs1 = self.make_worker_hs("synapse.app.generic_worker")
+        hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+        start_count = self._count_remote_media()
+
+        # Make two requests without responding to the outbound media requests.
+        channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+        channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+        # Respond to the first outbound media request and check that the client
+        # request fails as expected, since no storage backend is available
+        request1.setResponseCode(200)
+        request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+        request1.write(b"Hello!")
+        request1.finish()
+
+        self.pump(0.1)
+
+        # With local storage disabled and no storage providers,
+        # we expect a 404 error
+        self.assertEqual(channel1.code, 404, channel1.result["body"])
+
+        # Now respond to the second with the same content.
+        request2.setResponseCode(200)
+        request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+        request2.write(b"Hello!")
+        request2.finish()
+
+        self.pump(0.1)
+
+        # Same for the second request
+        self.assertEqual(channel2.code, 404, channel2.result["body"])
+
+        # No files should be stored locally
+        self.assertEqual(start_count, self._count_remote_media())
+
     @override_config({"enable_authenticated_media": False})
     def test_download_image_race(self) -> None:
         """Test that fetching remote *images* from two different processes at
@@ -428,6 +475,53 @@ class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         # We expect only one new file to have been persisted.
         self.assertEqual(start_count + 1, self._count_remote_media())
 
+    @override_config({"enable_local_media_storage": False})
+    def test_download_simple_file_race_no_local_storage(self) -> None:
+        """Test that fetching remote media fails cleanly when local storage is disabled.
+
+        When enable_local_media_storage is False and no media_storage_providers
+        are configured, there is no storage backend at all, so the requests are
+        expected to return 404 and no files should be stored locally.
+        """
+        hs1 = self.make_worker_hs("synapse.app.generic_worker")
+        hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+        start_count = self._count_remote_media()
+
+        # Make two requests without responding to the outbound media requests.
+        channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+        channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+        # Respond to the first outbound media request and check that the client
+        # request fails with a 404, since no storage backend is available
+        request1.setResponseCode(200)
+        request1.responseHeaders.setRawHeaders(
+            b"Content-Type",
+            ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+        )
+        request1.write(self.file_data)
+        request1.finish()
+
+        self.pump(0.1)
+        # With local storage disabled and no storage providers,
+        # we expect a 404 error
+        self.assertEqual(channel1.code, 404, channel1.result["body"])
+
+        # Now respond to the second with the same content.
+        request2.setResponseCode(200)
+        request2.responseHeaders.setRawHeaders(
+            b"Content-Type",
+            ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+        )
+        request2.write(self.file_data)
+        request2.finish()
+
+        self.pump(0.1)
+
+        self.assertEqual(channel2.code, 404, channel2.result["body"])
+
+        # With local storage disabled, no files should be stored locally
+        self.assertEqual(start_count, self._count_remote_media())
+
     def test_download_image_race(self) -> None:
         """Test that fetching remote *images* from two different processes at
         the same time works.