upload
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from typing import Any, TypeGuard, override
|
||||
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
||||
from typing import Any, TypeGuard, cast, overload, override
|
||||
|
||||
from app.subroutines.asyncutils import agzip
|
||||
from app.exceptions import LifespanError
|
||||
from app.types_ import (
|
||||
AnyScope,
|
||||
AsyncCallable,
|
||||
LifespanScope,
|
||||
PassthroughDecorator,
|
||||
Receive,
|
||||
ReceiveLifespan,
|
||||
Send,
|
||||
@@ -14,16 +17,37 @@ from app.types_ import (
|
||||
from .base import Component as _Component
|
||||
|
||||
|
||||
async def resolve_context[T](
|
||||
*async_generators: AsyncGenerator[T, None],
|
||||
) -> AsyncGenerator[tuple[T, ...], None]:
|
||||
"""
|
||||
Resolve `AsyncGenerator`s context.
|
||||
"""
|
||||
iterators = [ag.__aiter__() for ag in async_generators]
|
||||
|
||||
while True:
|
||||
try:
|
||||
results = await asyncio.gather(
|
||||
*[iterator.__anext__() for iterator in iterators]
|
||||
)
|
||||
yield tuple(results)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
yield ()
|
||||
|
||||
|
||||
class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
|
||||
startups: list[AsyncCallable[[], None]]
|
||||
shutdowns: list[AsyncCallable[[], None]]
|
||||
contexts: list[Callable[[], AsyncGenerator[Any, None]]]
|
||||
contexts: list[tuple[str | None, Callable[[], AsyncGenerator[Any, None]]]]
|
||||
loaded_context: dict[str, Any]
|
||||
|
||||
def __init__(self, *args: Any, **kwds: Any) -> None:
|
||||
super().__init__(*args, **kwds)
|
||||
self.startups = []
|
||||
self.shutdowns = []
|
||||
self.contexts = []
|
||||
self.loaded_context = {}
|
||||
|
||||
@override
|
||||
async def condition(self, scope: AnyScope) -> TypeGuard[LifespanScope]:
|
||||
@@ -34,10 +58,18 @@ class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
|
||||
self, scope: LifespanScope, receive: Receive[ReceiveLifespan], send: Send
|
||||
) -> None:
|
||||
message = await receive()
|
||||
async for _ in agzip(*[ctx() for ctx in self.contexts]):
|
||||
async for ctxs in resolve_context(*[ctx[1]() for ctx in self.contexts]):
|
||||
if message["type"] == "lifespan.startup":
|
||||
for fn in self.startups:
|
||||
await fn()
|
||||
for name, val in zip((ctx[0] for ctx in self.contexts), ctxs):
|
||||
if name is None:
|
||||
continue
|
||||
if name in self.loaded_context:
|
||||
raise LifespanError(
|
||||
f"Name {name!r} is already used by context {self.loaded_context[name]!r}."
|
||||
)
|
||||
self.loaded_context[name] = val
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
for fn in self.shutdowns:
|
||||
@@ -54,6 +86,67 @@ class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
|
||||
self.shutdowns.append(fn)
|
||||
return fn
|
||||
|
||||
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](self, fn: Ctx_T) -> Ctx_T:
|
||||
self.contexts.append(fn)
|
||||
@overload
|
||||
def on_context(
|
||||
self, *, name: str | None = None
|
||||
) -> PassthroughDecorator[Callable[[], AsyncGenerator[Any, None]]]: ...
|
||||
|
||||
@overload
|
||||
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](
|
||||
self, fn: Ctx_T
|
||||
) -> Ctx_T: ...
|
||||
|
||||
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](
|
||||
self, fn: Ctx_T | None = None, *, name: str | None = None
|
||||
) -> PassthroughDecorator[Ctx_T] | Ctx_T:
|
||||
if fn is None:
|
||||
|
||||
def __wrap_context(fn: Ctx_T) -> Ctx_T:
|
||||
self.contexts.append((name, fn))
|
||||
return fn
|
||||
|
||||
return __wrap_context
|
||||
self.contexts.append((name, fn))
|
||||
return fn
|
||||
|
||||
@overload
|
||||
def add_managed_context(
|
||||
self,
|
||||
ctx: AbstractContextManager[Any, Any],
|
||||
async_: None = None,
|
||||
name: str | None = None,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def add_managed_context(
|
||||
self,
|
||||
ctx: AbstractAsyncContextManager[Any, Any],
|
||||
async_: None = None,
|
||||
name: str | None = None,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def add_managed_context(
|
||||
self, ctx: Any, async_: bool, name: str | None = None
|
||||
) -> None: ...
|
||||
|
||||
def add_managed_context(
|
||||
self,
|
||||
ctx: AbstractContextManager[Any, Any] | AbstractAsyncContextManager[Any, Any],
|
||||
async_: bool | None = None,
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
if async_ is None:
|
||||
async_ = isinstance(ctx, AbstractAsyncContextManager)
|
||||
|
||||
@self.on_context(name=name)
|
||||
async def __make_context() -> AsyncGenerator[Any, Any]: # pyright: ignore[reportUnusedFunction]
|
||||
if async_:
|
||||
async with cast(AbstractAsyncContextManager[Any, Any], ctx) as c:
|
||||
yield c
|
||||
else:
|
||||
with cast(AbstractContextManager[Any, Any], ctx) as c:
|
||||
yield c
|
||||
|
||||
def get_context[T](self, name: str, type_: type[T] | None = None) -> T: # pyright: ignore[reportUnusedParameter]
|
||||
return cast(T, self.loaded_context[name])
|
||||
|
||||
Reference in New Issue
Block a user