Caching responses
The cache module provides a framework and various implementations for storing responses to common backend data stores.
See also
This module implements some common backend data stores, though you may wish to extend this functionality. See Writing a ResponseRepository and Writing a ResponseCache below for how to do so.
Core concepts and usage
There are three classes that are key to setting up and managing the backend.
ResponseRepositorySettings
Stores the settings which define where and how a response is stored in the cache.
from aiorequestful.cache.backend.base import ResponseRepositorySettings
settings = ResponseRepositorySettings(name="main")
You may also define a handler to transform the data before storing it in the cache. If not defined, the payload data will be extracted and stored as simple text data.
from aiorequestful.response.payload import JSONPayloadHandler
payload_handler = JSONPayloadHandler()
settings = ResponseRepositorySettings(name="main", payload_handler=payload_handler)
However, the ResponseRepositorySettings class is abstract, so none of the above code will actually run.
In order to instantiate our settings, we need to implement the ResponseRepositorySettings interface by defining the following:
- the key that can be used to identify a response in the cache repository
- the names of the fields of each of these keys
- how we get the name to assign to a cached response from the payload data
i.e. you will need to implement the following interface:
@dataclass
class ResponseRepositorySettings[V](metaclass=ABCMeta):
"""Settings for a response type from a given endpoint to be used to configure a repository in the cache backend."""
    #: The name of the repository in the backend
name: str
#: Handles payload data conversion to/from expected format for de/serialization.
payload_handler: PayloadHandler[V] = field(default=StringPayloadHandler())
@property
@abstractmethod
def fields(self) -> tuple[str, ...]:
"""
The names of the fields relating to the keys extracted by :py:meth:`get_key` in the order
in which they appear from the results of this method.
"""
raise NotImplementedError
@abstractmethod
def get_key(self, **kwargs) -> tuple:
"""
        Extracts the key to assign to a cache entry in the repository from the given request kwargs.
See aiohttp reference for more info on available kwargs:
https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession.request
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def get_name(payload: V) -> str | None:
"""Extracts the name to assign to a cache entry in the repository from a given ``response``."""
raise NotImplementedError
As an example, see below for extracting data from a Spotify Web API response.
from http import HTTPMethod
from typing import Any, Unpack
from yarl import URL
from aiorequestful.cache.backend.base import ResponseRepositorySettings
from aiorequestful.types import RequestKwargs
class SpotifyRepositorySettings(ResponseRepositorySettings):
@property
def fields(self) -> tuple[str, ...]:
return "id", "version"
def get_key(self, method: str | HTTPMethod, url: str | URL, **__) -> tuple[str | None, str | None]:
        if HTTPMethod(method) != HTTPMethod.GET: # don't store any response that is not from a GET request
return None, None
url = URL(url) # e.g. https://api.spotify.com/v1/tracks/6fWoFduMpBem73DMLCOh1Z
path_parts = url.path.strip("/").split("/") # '<version>', '<name>', '<id>', ...
if len(path_parts) < 3:
return None, None
return path_parts[2], path_parts[0]
def get_name(self, payload: dict[str, Any]) -> str | None:
if payload.get("type") == "user":
return payload["display_name"]
return payload.get("name")
Here we see that the ResponseRepositorySettings.fields are defined as the id and the version.
These are the fields required to identify a unique response in the repository.
In the ResponseRepositorySettings.get_key() method, we extract the id and the version and return them.
We also return None, None for any request that is not a GET request, or if the URL does not match the expected format.
This prevents the cache from identifying the response, so it will not be cached.
We also define the ResponseRepositorySettings.get_name() method to extract the name to assign to a cached response from its payload.
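As a quick illustration of this behaviour, we can call these methods directly on the example settings class above (a sketch only; the URL and payload values are illustrative):
settings = SpotifyRepositorySettings(name="tracks")
key = settings.get_key(method="GET", url="https://api.spotify.com/v1/tracks/6fWoFduMpBem73DMLCOh1Z")
assert key == ("6fWoFduMpBem73DMLCOh1Z", "v1")  # the id and version extracted from the URL path
assert settings.get_key(method="POST", url="https://api.spotify.com/v1/tracks") == (None, None)  # not a GET request
assert settings.get_name({"type": "track", "name": "super cool song"}) == "super cool song"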
ResponseRepository
Once we've defined our ResponseRepositorySettings, we can use these to define a ResponseRepository.
This represents a store of data relating to the given settings.
from datetime import timedelta
from aiorequestful.cache.backend.base import ResponseRepository
from aiorequestful.response.payload import JSONPayloadHandler
settings = SpotifyRepositorySettings(name="tracks", payload_handler=JSONPayloadHandler())
repository = ResponseRepository(settings=settings, expire=timedelta(weeks=2))
The expire parameter here defines how long after caching a response it will remain valid and retrievable in the cache.
For example, if we cache a response today and make the same request with this cache in 1 week, we will retrieve the cached response.
However, if we make the same request in 3 weeks, the HTTP request will be made without retrieving the cached response.
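As a rough illustration of this window in plain Python (not library code; the two-week delta mirrors the expire value above):
from datetime import datetime, timedelta

expire = timedelta(weeks=2)
expires_at = datetime.now() + expire  # the point after which a response cached now is invalidated
assert datetime.now() + timedelta(weeks=1) < expires_at  # a repeat request in 1 week is served from the cache
assert datetime.now() + timedelta(weeks=3) > expires_at  # a repeat request in 3 weeks falls back to a real HTTP request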
Below is an example of how we might use this repository to handle caching our responses.
async def process_response(rep: ResponseRepository, req: tuple, payload: dict[str, Any]) -> None:
await rep.create() # create the repository in the cache backend
await rep # does the same as above
await rep.save_response(response=(req, payload)) # store the payload in the cache
assert await rep.get_response(request=req) == payload # retrieve the cached payload
    assert rep.get_key_from_request(req) == req # get the key used to store a given request
assert await rep.count(True) == 1 # the number of responses cached, including expired responses
assert await rep.count(False) == 1 # the number of responses cached, excluding expired responses
assert await rep.contains(req) # check the key is in the repository
await rep.delete_response(request=req) # delete the response
await rep.clear() # OR clear all responses
await rep.clear(True) # OR clear only expired responses
async def process_request(rep: ResponseRepository) -> None:
request = ("6fWoFduMpBem73DMLCOh1Z", "v1")
payload = {"name": "super cool song"}
await process_response(rep=rep, req=request, payload=payload)
Here we create the repository and manage storing responses manually.
However, it is advised that we use a ResponseCache to manage our repositories.
ResponseCache
Manages a collection of repositories along with a connection to the backend.
from aiorequestful.cache.backend import ResponseCache
def repository_getter(cache: ResponseCache, url: str | URL) -> ResponseRepository:
path = URL(url).path
path_split = [part.replace("-", "_") for part in path.split("/")[2:]]
if len(path_split) < 3:
name = path_split[0]
else:
name = "_".join([path_split[0].rstrip("s"), path_split[2].rstrip("s") + "s"])
return cache.get(name)
response_cache = ResponseCache(
cache_name="cache", repository_getter=repository_getter, expire=timedelta(weeks=2)
)
# or we can call the 'create' class method to simplify cache creation
response_cache = ResponseCache.create(value="db")
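To illustrate the mapping performed by repository_getter above, the sketch below mirrors its name derivation on its own (the artist id is just an example value):
from yarl import URL

def name_for(url: str | URL) -> str:
    # mirrors the name derivation inside repository_getter, for illustration only
    parts = [part.replace("-", "_") for part in URL(url).path.split("/")[2:]]
    if len(parts) < 3:
        return parts[0]
    return "_".join([parts[0].rstrip("s"), parts[2].rstrip("s") + "s"])

assert name_for("https://api.spotify.com/v1/tracks/6fWoFduMpBem73DMLCOh1Z") == "tracks"
assert name_for("https://api.spotify.com/v1/artists/3WrFJ7ztbogyGnTHbHJFl2/albums") == "artist_albums"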
Now we can use the cache to help manage the creation of repositories.
from aiohttp import ClientRequest, ClientResponse
async def setup_cache(cache: ResponseCache) -> None:
payload_handler = JSONPayloadHandler()
cache.create_repository(SpotifyRepositorySettings(name="tracks", payload_handler=payload_handler))
cache.create_repository(SpotifyRepositorySettings(name="albums", payload_handler=payload_handler))
cache.create_repository(SpotifyRepositorySettings(name="artists", payload_handler=payload_handler))
See also
Here we assign the PayloadHandler for each repository manually.
However, as we will usually use the ResponseCache as part of the RequestHandler, we do not need to add the PayloadHandler here as the RequestHandler will manage that for us.
For more info on how this can be used, see Handling the response payload.
Warning
By calling ResponseCache.create_repository(), we are only defining the repository on the ResponseCache object and have not actually created the repository on the backend.
To create the repositories on the backend, we will need to connect to the cache and then create them.
This is most easily achieved by entering the context of the ResponseCache, as shown below.
async def process_response(cache: ResponseCache, request: ClientRequest, response: ClientResponse) -> None:
    async with cache: # connect and set up the repositories on the backend
await cache.save_response(response=response) # store the payload in the cache
assert await cache.get_response(request=request) == await response.json() # retrieve the cached payload
await cache.delete_response(request=request) # delete the response
Cached sessions and responses
To aid seamless usage of cached and non-cached setups when using the RequestHandler, the cache module also provides a CachedSession and a CachedResponse implementation.
CachedResponse
Takes a ClientRequest and payload to mock a ClientResponse.
Allows for seamless usage of cached responses as if they were returned by a genuine HTTP request.
You can use CachedResponse in exactly the same way you would a regular ClientResponse.
import json
from aiohttp import ClientRequest
from yarl import URL
from aiorequestful.cache.response import CachedResponse
url = URL("https://official-joke-api.appspot.com/jokes/programming/random")
request = ClientRequest(method="GET", url=url)
payload = [
{
"type": "programming",
"setup": "I was gonna tell you a joke about UDP...",
"punchline": "...but you might not get it.",
"id": 72
}
]
response = CachedResponse(request, payload=json.dumps(payload))
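Once created, the cached payload can be read back with the usual ClientResponse methods (a minimal sketch; it assumes the JSON accessors behave exactly as they would on a live response):
async def read_cached(response: CachedResponse) -> None:
    # read the payload exactly as we would from a live ClientResponse
    data = await response.json()
    print(data[0]["punchline"])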
CachedSession
Takes a ResponseCache to mock a ClientSession.
The CachedSession will always attempt to call the cache first before falling back to making an actual HTTP request to get the response.
You can use CachedSession in exactly the same way you would a regular ClientSession.
from aiorequestful.cache.backend.sqlite import SQLiteCache
from aiorequestful.cache.session import CachedSession
cache = SQLiteCache.connect_with_in_memory_db()
session = CachedSession(cache=cache)
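Requests are then made through the session just as with any ClientSession. Below is a minimal sketch, assuming the standard request methods route through the cache as described above:
async def get_joke(session: CachedSession) -> None:
    url = "https://official-joke-api.appspot.com/jokes/programming/random"
    async with session:  # open the session; we assume this also prepares the cache connection
        async with session.get(url) as response:  # served from the cache if present, otherwise via a real HTTP request
            print(await response.json())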
SQLite
SQLiteTable
In addition to the available kwargs for the base ResponseRepository, we also need to provide a Connection to the database.
from datetime import timedelta
import aiosqlite
from aiorequestful.cache.backend.sqlite import SQLiteTable
connection = aiosqlite.connect(database="file::memory:?cache=shared", uri=True)
settings = SpotifyRepositorySettings(name="tracks")
repository = SQLiteTable(
connection=connection, settings=settings, expire=timedelta(weeks=2)
)
SQLiteCache
The implementation provides a variety of standard database connections to aid in the quick setup of an SQLite cache backend.
While we can instantiate the SQLiteCache directly by providing a connector that returns a Connection object…
import tempfile
from aiorequestful.cache.backend.sqlite import SQLiteCache
cache = SQLiteCache(
cache_name="__IN_MEMORY__",
connector=lambda: aiosqlite.connect(database="file::memory:?cache=shared", uri=True),
)
…it is preferable to use one of the class methods for connecting to an SQLite database backend.
cache = SQLiteCache.connect(value="file::memory:?cache=shared")
cache = SQLiteCache.connect_with_path(path=f"{tempfile.gettempdir()}/path/to/db.sqlite")
cache = SQLiteCache.connect_with_temp_db()
cache = SQLiteCache.connect_with_in_memory_db()
Writing a ResponseRepository
To implement a ResponseRepository, you will need to implement the abstract methods as shown below.
class ResponseRepository[K: tuple, V: Any](AsyncIterable[tuple[K, V]], metaclass=ABCMeta):
"""
Represents a repository in the backend cache, providing a dict-like interface
for interacting with this repository.
A repository is a data store within the backend e.g. a table in a database.
:param settings: The settings to use to identify and interact with the repository in the backend.
:param expire: The expiry time to apply to cached responses after which responses are invalidated.
"""
__slots__ = ("logger", "connection", "settings", "_expire")
# noinspection PyPropertyDefinition,PyMethodParameters
@property
@abstractmethod
def _required_modules(cls) -> list:
"""The modules required to instantiate this repository"""
return []
@property
def expire(self) -> datetime:
"""The datetime representing the maximum allowed expiry time from now."""
return datetime.now() + self._expire
@classmethod
@abstractmethod
def create(cls, *args, **kwargs) -> Self:
"""
Set up the backend repository in the backend database if it doesn't already exist
and return the initialised object that represents it.
"""
raise NotImplementedError
def __init__(self, settings: ResponseRepositorySettings[V], expire: timedelta | relativedelta = DEFAULT_EXPIRE):
#: The :py:class:`logging.Logger` for this object
self.logger: logging.Logger = logging.getLogger(__name__)
#: The settings to use to identify and interact with the repository in the backend.
self.settings = settings
self._expire = expire
#: The current connection to the backend.
self.connection = None
@abstractmethod
def __await__(self) -> Generator[None, None, Self]:
raise NotImplementedError
def __hash__(self):
return hash(self.settings.name)
@abstractmethod
async def commit(self) -> None:
"""Commit the changes to the repository"""
raise NotImplementedError
@abstractmethod
async def close(self) -> None:
"""Close the connection to the repository."""
raise NotImplementedError
@abstractmethod
async def count(self, include_expired: bool = True) -> int:
"""
Get the number of responses in this repository.
:param include_expired: Whether to include expired responses in the final count.
:return: The number of responses in this repository.
"""
raise NotImplementedError
@abstractmethod
async def contains(self, request: RepositoryRequestType[K]) -> bool:
"""Check whether the repository contains a given ``request``"""
raise NotImplementedError
@abstractmethod
async def clear(self, expired_only: bool = False) -> int:
"""
Clear the repository of all entries.
:param expired_only: Whether to only remove responses that have expired.
:return: The number of responses cleared from the repository.
"""
raise NotImplementedError
async def serialize(self, value: Any) -> V | None:
"""
Serialize a given ``value`` to a type that can be persisted to the repository.
:return: Serialized object if serializing is possible, None otherwise.
"""
if value is None:
return
try:
return await self.settings.payload_handler.serialize(value)
except PayloadHandlerError:
return
async def deserialize(self, value: V | None) -> Any:
"""
Deserialize a value from the repository to the expected response value type.
:return: Deserialized object if deserializing is possible, None otherwise.
"""
if value is None:
return
try:
return await self.settings.payload_handler.deserialize(value)
except PayloadHandlerError:
return
@abstractmethod
def get_key_from_request(self, request: RepositoryRequestType[K]) -> K:
"""Extract the key to use when persisting responses for a given ``request``"""
raise NotImplementedError
@abstractmethod
async def get_response(self, request: RepositoryRequestType[K]) -> V | None:
"""
Get the response relating to the given ``request`` from this repository if it exists.
:return: The result if found.
"""
raise NotImplementedError
async def get_responses(self, requests: Collection[RepositoryRequestType[K]]) -> list[V]:
"""
Get the responses relating to the given ``requests`` from this repository if they exist.
:return: Results unordered.
"""
tasks = asyncio.gather(*map(self.get_response, requests))
return list(filter(lambda result: result is not None, await tasks))
async def save_response(self, response: Collection[K, V] | ClientResponse) -> None:
"""Save the given ``response`` to this repository if a key can be extracted from it. Safely fail if not"""
if isinstance(response, Collection):
key, value = response
else:
key = self.get_key_from_request(response)
if not key:
return
value: V = await self.deserialize(response)
await self._set_item_from_key_value_pair(key, await self.serialize(value))
@abstractmethod
async def _set_item_from_key_value_pair(self, __key: K, __value: Any) -> None:
raise NotImplementedError
async def save_responses(self, responses: Mapping[K, V] | Collection[ClientResponse]) -> None:
"""
Save the given ``responses`` to this repository if a key can be extracted from them.
Safely fail on those that can't.
"""
if isinstance(responses, Mapping):
tasks = [
self._set_item_from_key_value_pair(key, await self.serialize(value))
for key, value in responses.items()
]
else:
tasks = map(self.save_response, responses)
await asyncio.gather(*tasks)
@abstractmethod
async def delete_response(self, request: RepositoryRequestType[K]) -> bool:
"""
Delete the given ``request`` from this repository if it exists.
:return: True if deleted in the repository and False if ``request`` was not found in the repository.
"""
raise NotImplementedError
async def delete_responses(self, requests: Collection[RepositoryRequestType[K]]) -> int:
"""
Delete the given ``requests`` from this repository if they exist.
:return: The number of the given ``requests`` deleted in the repository.
"""
tasks = asyncio.gather(*map(self.delete_response, requests))
return sum(await tasks)
As an example, the following implements the SQLiteTable.
class SQLiteTable[K: tuple[Any, ...], V: str](ResponseRepository[K, V]):
__slots__ = ()
#: The column under which a response's name is stored in the table
name_column = "name"
#: The column under which the response payload is stored in the table
payload_column = "payload"
#: The column under which the response cache time is stored in the table
cached_column = "cached_at"
#: The column under which the response expiry time is stored in the table
expiry_column = "expires_at"
# noinspection PyMethodParameters
@classproperty
def _required_modules(cls) -> list:
return [aiosqlite]
async def create(self) -> Self:
ddl_sep = "\t, "
ddl = "\n".join((
f'CREATE TABLE IF NOT EXISTS "{self.settings.name}" (',
'\t' + f'\n{ddl_sep}'.join(
f'"{key}" {data_type} NOT NULL' for key, data_type in self._primary_key_columns.items()
),
f'{ddl_sep}"{self.name_column}" TEXT',
f'{ddl_sep}"{self.cached_column}" TIMESTAMP NOT NULL',
f'{ddl_sep}"{self.expiry_column}" TIMESTAMP NOT NULL',
f'{ddl_sep}"{self.payload_column}" TEXT',
f'{ddl_sep}PRIMARY KEY ("{'", "'.join(self._primary_key_columns)}")',
');',
f'CREATE INDEX IF NOT EXISTS idx_{self.expiry_column} '
f'ON "{self.settings.name}"({self.expiry_column});'
))
self.logger.debug(f"Creating {self.settings.name!r} table with the following DDL:\n{ddl}")
await self.connection.executescript(ddl)
await self.commit()
return self
def __init__(
self,
connection: aiosqlite.Connection,
settings: ResponseRepositorySettings,
expire: timedelta | relativedelta = DEFAULT_EXPIRE,
):
required_modules_installed(self._required_modules, self)
super().__init__(settings=settings, expire=expire)
self.connection = connection
def __await__(self) -> Generator[None, None, Self]:
return self.create().__await__()
async def commit(self) -> None:
"""Commit the transactions to the database."""
try:
await self.connection.commit()
except ValueError:
pass
async def close(self) -> None:
try:
await self.commit()
await self.connection.close()
except ValueError:
pass
@property
def _primary_key_columns(self) -> Mapping[str, str]:
"""A map of column names to column data types for the primary keys of this repository."""
expected_columns = self.settings.fields
keys = {"method": f"VARCHAR({max(len(method) for method in HTTPMethod)})"}
if "id" in expected_columns:
keys["id"] = "VARCHAR(255)"
if "offset" in expected_columns:
keys["offset"] = "INT"
if "size" in expected_columns:
keys["size"] = "INT"
return keys
def get_key_from_request(self, request: RepositoryRequestType[K]) -> K | None:
if isinstance(request, ClientRequest | ClientResponse):
request = request.request_info
if not isinstance(request, RequestInfo):
return request # `request` is the key
key = self.settings.get_key(method=request.method, url=request.url, headers=request.headers)
if any(part is None for part in key):
return
return str(request.method).upper(), *key
async def count(self, include_expired: bool = True) -> int:
query = f'SELECT COUNT(*) FROM "{self.settings.name}"'
params = []
if not include_expired:
query += f'\nWHERE "{self.expiry_column}" > ?'
params.append(datetime.now().isoformat())
async with self.connection.execute(query, params) as cur:
row = await cur.fetchone()
return row[0]
async def contains(self, request: RepositoryRequestType[K]) -> bool:
key = self.get_key_from_request(request)
query = "\n".join((
f'SELECT COUNT(*) FROM "{self.settings.name}"',
f'WHERE "{self.expiry_column}" > ?',
f'\tAND {'\n\tAND '.join(f'"{key}" = ?' for key in self._primary_key_columns)}',
))
async with self.connection.execute(query, (datetime.now().isoformat(), *key)) as cur:
rows = await cur.fetchone()
return rows[0] > 0
async def clear(self, expired_only: bool = False) -> int:
query = f'DELETE FROM "{self.settings.name}"'
params = []
if expired_only:
            query += f'\nWHERE "{self.expiry_column}" <= ?'
params.append(datetime.now().isoformat())
async with self.connection.execute(query, params) as cur:
count = cur.rowcount
return count
async def __aiter__(self):
query = "\n".join((
f'SELECT "{'", "'.join(self._primary_key_columns)}", "{self.payload_column}" ',
f'FROM "{self.settings.name}"',
f'WHERE "{self.expiry_column}" > ?',
))
async with self.connection.execute(query, (datetime.now().isoformat(),)) as cur:
async for row in cur:
yield row[:-1], await self.deserialize(row[-1])
async def get_response(self, request: RepositoryRequestType[K]) -> V | None:
key = self.get_key_from_request(request)
if not key:
return
query = "\n".join((
            f'SELECT "{self.payload_column}" FROM "{self.settings.name}"',
f'WHERE "{self.payload_column}" IS NOT NULL',
f'\tAND "{self.expiry_column}" > ?',
f'\tAND {'\n\tAND '.join(f'"{key}" = ?' for key in self._primary_key_columns)}',
))
async with self.connection.execute(query, (datetime.now().isoformat(), *key)) as cur:
row = await cur.fetchone()
if not row:
return
return await self.deserialize(row[0])
async def _set_item_from_key_value_pair(self, __key: K, __value: V) -> None:
columns = (
*self._primary_key_columns,
self.name_column,
self.cached_column,
self.expiry_column,
self.payload_column
)
query = "\n".join((
f'INSERT OR REPLACE INTO "{self.settings.name}" (',
f'\t"{'", "'.join(columns)}"',
') ',
f'VALUES({','.join('?' * len(columns))});',
))
params = (
*__key,
self.settings.get_name(await self.deserialize(__value)),
datetime.now().isoformat(),
self.expire.isoformat(),
__value,
)
await self.connection.execute(query, params)
async def delete_response(self, request: RepositoryRequestType[K]) -> bool:
key = self.get_key_from_request(request)
query = "\n".join((
f'DELETE FROM "{self.settings.name}"',
f'WHERE {'\n\tAND '.join(f'"{key}" = ?' for key in self._primary_key_columns)}',
))
async with self.connection.execute(query, key) as cur:
count = cur.rowcount
return count > 0
Writing a ResponseCache
To implement a ResponseCache, you will need to implement the abstract methods as shown below.
class ResponseCache[R: ResponseRepository](MutableMapping[str, R], metaclass=ABCMeta):
"""
Represents a backend cache of many repositories, providing a dict-like interface for interacting with them.
:param cache_name: The name to give to this cache.
:param repository_getter: A function that can be used to identify the repository in this cache
that matches a given URL.
:param expire: The expiry time to apply to cached responses after which responses are invalidated.
"""
__slots__ = ("cache_name", "repository_getter", "expire", "_repositories")
# noinspection PyMethodParameters
@classproperty
@abstractmethod
def type(cls) -> str:
"""A string representing the type of the backend this class represents."""
# raise NotImplementedError - omitted here as it causes docs build to fail
@classmethod
@abstractmethod
async def connect(cls, value: Any, **kwargs) -> Self:
"""Connect to the backend from a given generic ``value``."""
raise NotImplementedError
def __init__(
self,
cache_name: str,
repository_getter: Callable[[Self, URLInput], R | None] = None,
expire: timedelta | relativedelta = DEFAULT_EXPIRE,
):
super().__init__()
#: The name to give to this cache.
self.cache_name = cache_name
#: A function that can be used to identify the repository in this cache that matches a given URL.
self.repository_getter = repository_getter
#: The expiry time to apply to cached responses after which responses are invalidated.
self.expire = expire
self._repositories: dict[str, R] = {}
@abstractmethod
def __await__(self) -> Generator[None, None, Self]:
raise NotImplementedError
async def __aenter__(self) -> Self:
return await self
@abstractmethod
async def __aexit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __repr__(self):
return repr(self._repositories)
def __str__(self):
return str(self._repositories)
def __iter__(self):
return iter(self._repositories)
def __len__(self):
return len(self._repositories)
def __getitem__(self, item):
return self._repositories[item]
def __setitem__(self, key, value):
self._repositories[key] = value
def __delitem__(self, key):
del self._repositories[key]
@abstractmethod
async def commit(self) -> None:
"""Commit the changes to the cache"""
raise NotImplementedError
@abstractmethod
async def close(self):
"""Close the connection to the repository."""
raise NotImplementedError
@abstractmethod
def create_repository(self, settings: ResponseRepositorySettings) -> ResponseRepository:
"""
Create and return a :py:class:`SQLiteResponseStorage` and store this object in this cache.
Creates a repository with the given ``settings`` in the cache if it doesn't exist.
"""
raise NotImplementedError
def get_repository_from_url(self, url: URLInput) -> R | None:
"""Returns the repository to use from the stored repositories in this cache for the given ``url``."""
if self.repository_getter is not None:
return self.repository_getter(self, url)
def get_repository_from_requests(self, requests: UnitCollection[CacheRequestType]) -> R | None:
"""Returns the repository to use from the stored repositories in this cache for the given ``requests``."""
requests = get_iterator(requests)
results = {self.get_repository_from_url(request.url) for request in requests}
if len(results) > 1:
raise CacheError(
"Too many different types of requests given. Given requests must relate to the same repository type"
)
return next(iter(results), None)
async def get_response(self, request: CacheRequestType) -> Any:
"""
Get the response relating to the given ``request`` from the appropriate repository if it exists.
:return: The result if found.
"""
repository = self.get_repository_from_requests([request])
if repository is not None:
return await repository.get_response(request)
async def get_responses(self, requests: Collection[CacheRequestType]) -> list:
"""
Get the responses relating to the given ``requests`` from the appropriate repository if they exist.
:return: Results unordered.
"""
repository = self.get_repository_from_requests(requests)
if repository is not None:
return await repository.get_responses(requests)
async def save_response(self, response: ClientResponse) -> None:
"""Save the given ``response`` to the appropriate repository if a key can be extracted from it."""
repository = self.get_repository_from_requests([response])
if repository is not None:
return await repository.save_response(response)
async def save_responses(self, responses: Collection[ClientResponse]) -> None:
"""
Save the given ``responses`` to the appropriate repository if a key can be extracted from them.
Safely fail on those that can't.
"""
repository = self.get_repository_from_requests(responses)
if repository is not None:
return await repository.save_responses(responses)
async def delete_response(self, request: CacheRequestType) -> bool:
"""
Delete the given ``request`` from the appropriate repository if it exists.
:return: True if deleted in the repository and False if ``request`` was not found in the repository.
"""
repository = self.get_repository_from_requests([request])
if repository is not None:
return await repository.delete_response(request)
async def delete_responses(self, requests: Collection[CacheRequestType]) -> int:
"""
Delete the given ``requests`` from the appropriate repository.
:return: The number of the given ``requests`` deleted in the repository.
"""
repository = self.get_repository_from_requests(requests)
if repository is not None:
return await repository.delete_responses(requests)
As an example, the following implements the SQLiteCache.
class SQLiteCache(ResponseCache[SQLiteTable]):
__slots__ = ("_connector", "connection")
# noinspection PyMethodParameters
@classproperty
def type(cls):
return "sqlite"
@property
def closed(self):
"""Is the stored client session closed."""
return self.connection is None or not self.connection.is_alive()
@staticmethod
def _get_sqlite_path(path: Path) -> Path:
return path.with_suffix(".sqlite")
@staticmethod
def _clean_kwargs[T: dict](kwargs: T) -> T:
kwargs.pop("cache_name", None)
kwargs.pop("connection", None)
return kwargs
@classmethod
def connect(cls, value: Any, **kwargs) -> Self:
return cls.connect_with_path(path=value, **kwargs)
@classmethod
def connect_with_path(cls, path: str | Path, **kwargs) -> Self:
"""Connect with an SQLite DB at the given ``path`` and return an instantiated :py:class:`SQLiteResponseCache`"""
path = cls._get_sqlite_path(Path(path))
os.makedirs(path.parent, exist_ok=True)
return cls(
cache_name=str(path),
connector=lambda: aiosqlite.connect(database=path),
**cls._clean_kwargs(kwargs)
)
@classmethod
def connect_with_in_memory_db(cls, **kwargs) -> Self:
"""Connect with an in-memory SQLite DB and return an instantiated :py:class:`SQLiteResponseCache`"""
return cls(
cache_name="__IN_MEMORY__",
connector=lambda: aiosqlite.connect(database="file::memory:?cache=shared", uri=True),
**cls._clean_kwargs(kwargs)
)
@classmethod
def connect_with_temp_db(cls, name: str = f"{PROGRAM_NAME.lower()}_db.tmp", **kwargs) -> Self:
"""Connect with a temporary SQLite DB and return an instantiated :py:class:`SQLiteResponseCache`"""
path = cls._get_sqlite_path(Path(gettempdir(), name))
return cls(
cache_name=name,
connector=lambda: aiosqlite.connect(database=path),
**cls._clean_kwargs(kwargs)
)
def __init__(
self,
cache_name: str,
connector: Callable[[], aiosqlite.Connection],
repository_getter: Callable[[Self, URLInput], SQLiteTable] = None,
expire: timedelta | relativedelta = DEFAULT_EXPIRE,
):
required_modules_installed(SQLiteTable._required_modules, self)
super().__init__(cache_name=cache_name, repository_getter=repository_getter, expire=expire)
self._connector = connector
#: The current connection to the SQLite database.
self.connection: aiosqlite.Connection | None = None
async def _connect(self) -> Self:
if self.closed:
self.connection = self._connector()
await self.connection
for repository in self._repositories.values():
repository.connection = self.connection
await repository.create()
return self
def __await__(self) -> Generator[None, None, Self]:
return self._connect().__await__()
async def __aexit__(self, __exc_type, __exc_value, __traceback) -> None:
if self.closed:
return
await self.commit()
await self.connection.__aexit__(__exc_type, __exc_value, __traceback)
self.connection = None
async def commit(self):
"""Commit the transactions to the database."""
if self.closed:
return
try:
await self.connection.commit()
except ValueError:
pass
async def close(self):
if self.closed:
return
try:
await self.commit()
await self.connection.close()
except ValueError:
pass
def create_repository(self, settings: ResponseRepositorySettings) -> SQLiteTable:
if settings.name in self:
raise CacheError(f"Repository already exists: {settings.name}")
repository = SQLiteTable(connection=self.connection, settings=settings, expire=self.expire)
self._repositories[settings.name] = repository
return repository