Files
myasgi/app/components/lifespan.py

153 lines
4.9 KiB
Python
Raw Normal View History

2025-04-15 16:36:23 +08:00
import asyncio
2025-04-11 15:04:22 +08:00
from collections.abc import AsyncGenerator, Callable
2025-04-15 16:36:23 +08:00
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import Any, TypeGuard, cast, overload, override
2025-04-09 16:54:34 +08:00
2025-04-15 16:36:23 +08:00
from app.exceptions import LifespanError
2025-04-11 15:04:22 +08:00
from app.types_ import (
AnyScope,
AsyncCallable,
LifespanScope,
2025-04-15 16:36:23 +08:00
PassthroughDecorator,
2025-04-11 15:04:22 +08:00
Receive,
ReceiveLifespan,
Send,
)
2025-04-09 16:54:34 +08:00
from .base import Component as _Component
2025-04-15 16:36:23 +08:00
async def resolve_context[T](
*async_generators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[T, ...], None]:
"""
Resolve `AsyncGenerator`s context.
"""
iterators = [ag.__aiter__() for ag in async_generators]
while True:
try:
results = await asyncio.gather(
*[iterator.__anext__() for iterator in iterators]
)
yield tuple(results)
except StopAsyncIteration:
break
yield ()
2025-04-09 16:54:34 +08:00
class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
2025-04-11 15:04:22 +08:00
startups: list[AsyncCallable[[], None]]
shutdowns: list[AsyncCallable[[], None]]
2025-04-15 16:36:23 +08:00
contexts: list[tuple[str | None, Callable[[], AsyncGenerator[Any, None]]]]
loaded_context: dict[str, Any]
2025-04-11 15:04:22 +08:00
def __init__(self, *args: Any, **kwds: Any) -> None:
super().__init__(*args, **kwds)
self.startups = []
self.shutdowns = []
self.contexts = []
2025-04-15 16:36:23 +08:00
self.loaded_context = {}
2025-04-11 15:04:22 +08:00
2025-04-09 16:54:34 +08:00
@override
async def condition(self, scope: AnyScope) -> TypeGuard[LifespanScope]:
return scope["type"] == "lifespan"
@override
2025-04-11 15:04:22 +08:00
async def handle(
self, scope: LifespanScope, receive: Receive[ReceiveLifespan], send: Send
) -> None:
message = await receive()
2025-04-15 16:36:23 +08:00
async for ctxs in resolve_context(*[ctx[1]() for ctx in self.contexts]):
2025-04-11 15:04:22 +08:00
if message["type"] == "lifespan.startup":
for fn in self.startups:
await fn()
2025-04-15 16:36:23 +08:00
for name, val in zip((ctx[0] for ctx in self.contexts), ctxs):
if name is None:
continue
if name in self.loaded_context:
raise LifespanError(
f"Name {name!r} is already used by context {self.loaded_context[name]!r}."
)
self.loaded_context[name] = val
2025-04-11 15:04:22 +08:00
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
for fn in self.shutdowns:
await fn()
await send({"type": "lifespan.shutdown.complete"})
return
2025-04-09 16:54:34 +08:00
message = await receive()
2025-04-11 15:04:22 +08:00
def on_startup[Call_T: AsyncCallable[[], None]](self, fn: Call_T) -> Call_T:
self.startups.append(fn)
return fn
def on_shutdown[Call_T: AsyncCallable[[], None]](self, fn: Call_T) -> Call_T:
self.shutdowns.append(fn)
return fn
2025-04-15 16:36:23 +08:00
@overload
def on_context(
self, *, name: str | None = None
) -> PassthroughDecorator[Callable[[], AsyncGenerator[Any, None]]]: ...
@overload
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](
self, fn: Ctx_T
) -> Ctx_T: ...
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](
self, fn: Ctx_T | None = None, *, name: str | None = None
) -> PassthroughDecorator[Ctx_T] | Ctx_T:
if fn is None:
def __wrap_context(fn: Ctx_T) -> Ctx_T:
self.contexts.append((name, fn))
return fn
return __wrap_context
self.contexts.append((name, fn))
2025-04-11 15:04:22 +08:00
return fn
2025-04-15 16:36:23 +08:00
@overload
def add_managed_context(
self,
ctx: AbstractContextManager[Any, Any],
async_: None = None,
name: str | None = None,
) -> None: ...
@overload
def add_managed_context(
self,
ctx: AbstractAsyncContextManager[Any, Any],
async_: None = None,
name: str | None = None,
) -> None: ...
@overload
def add_managed_context(
self, ctx: Any, async_: bool, name: str | None = None
) -> None: ...
def add_managed_context(
self,
ctx: AbstractContextManager[Any, Any] | AbstractAsyncContextManager[Any, Any],
async_: bool | None = None,
name: str | None = None,
) -> None:
if async_ is None:
async_ = isinstance(ctx, AbstractAsyncContextManager)
@self.on_context(name=name)
async def __make_context() -> AsyncGenerator[Any, Any]: # pyright: ignore[reportUnusedFunction]
if async_:
async with cast(AbstractAsyncContextManager[Any, Any], ctx) as c:
yield c
else:
with cast(AbstractContextManager[Any, Any], ctx) as c:
yield c
def get_context[T](self, name: str, type_: type[T] | None = None) -> T: # pyright: ignore[reportUnusedParameter]
return cast(T, self.loaded_context[name])