This commit is contained in:
2025-04-15 16:36:23 +08:00
parent dcceb275ad
commit 7cb0a87d9c
5 changed files with 111 additions and 33 deletions

View File

@@ -71,6 +71,7 @@ class HTTPComponent(
post: bool = False,
put: bool = False,
delete: bool = False,
head: bool = False,
) -> PassthroughDecorator[T]:
def __wrap_route(fn: T) -> T:
if get:
@@ -81,6 +82,8 @@ class HTTPComponent(
self.route_install(route, fn, type_="PUT")
if delete:
self.route_install(route, fn, type_="DELETE")
if head:
self.route_install(route, fn, type_="HEAD")
return fn
return __wrap_route

View File

@@ -1,11 +1,14 @@
import asyncio
from collections.abc import AsyncGenerator, Callable
from typing import Any, TypeGuard, override
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import Any, TypeGuard, cast, overload, override
from app.subroutines.asyncutils import agzip
from app.exceptions import LifespanError
from app.types_ import (
AnyScope,
AsyncCallable,
LifespanScope,
PassthroughDecorator,
Receive,
ReceiveLifespan,
Send,
@@ -14,16 +17,37 @@ from app.types_ import (
from .base import Component as _Component
async def resolve_context[T](
*async_generators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[T, ...], None]:
"""
Resolve `AsyncGenerator`s context.
"""
iterators = [ag.__aiter__() for ag in async_generators]
while True:
try:
results = await asyncio.gather(
*[iterator.__anext__() for iterator in iterators]
)
yield tuple(results)
except StopAsyncIteration:
break
yield ()
class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
startups: list[AsyncCallable[[], None]]
shutdowns: list[AsyncCallable[[], None]]
contexts: list[Callable[[], AsyncGenerator[Any, None]]]
contexts: list[tuple[str | None, Callable[[], AsyncGenerator[Any, None]]]]
loaded_context: dict[str, Any]
def __init__(self, *args: Any, **kwds: Any) -> None:
super().__init__(*args, **kwds)
self.startups = []
self.shutdowns = []
self.contexts = []
self.loaded_context = {}
@override
async def condition(self, scope: AnyScope) -> TypeGuard[LifespanScope]:
@@ -34,10 +58,18 @@ class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
self, scope: LifespanScope, receive: Receive[ReceiveLifespan], send: Send
) -> None:
message = await receive()
async for _ in agzip(*[ctx() for ctx in self.contexts]):
async for ctxs in resolve_context(*[ctx[1]() for ctx in self.contexts]):
if message["type"] == "lifespan.startup":
for fn in self.startups:
await fn()
for name, val in zip((ctx[0] for ctx in self.contexts), ctxs):
if name is None:
continue
if name in self.loaded_context:
raise LifespanError(
f"Name {name!r} is already used by context {self.loaded_context[name]!r}."
)
self.loaded_context[name] = val
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
for fn in self.shutdowns:
@@ -54,6 +86,67 @@ class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
self.shutdowns.append(fn)
return fn
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](self, fn: Ctx_T) -> Ctx_T:
self.contexts.append(fn)
@overload
def on_context(
self, *, name: str | None = None
) -> PassthroughDecorator[Callable[[], AsyncGenerator[Any, None]]]: ...
@overload
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](
self, fn: Ctx_T
) -> Ctx_T: ...
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](
self, fn: Ctx_T | None = None, *, name: str | None = None
) -> PassthroughDecorator[Ctx_T] | Ctx_T:
if fn is None:
def __wrap_context(fn: Ctx_T) -> Ctx_T:
self.contexts.append((name, fn))
return fn
return __wrap_context
self.contexts.append((name, fn))
return fn
@overload
def add_managed_context(
self,
ctx: AbstractContextManager[Any, Any],
async_: None = None,
name: str | None = None,
) -> None: ...
@overload
def add_managed_context(
self,
ctx: AbstractAsyncContextManager[Any, Any],
async_: None = None,
name: str | None = None,
) -> None: ...
@overload
def add_managed_context(
self, ctx: Any, async_: bool, name: str | None = None
) -> None: ...
def add_managed_context(
self,
ctx: AbstractContextManager[Any, Any] | AbstractAsyncContextManager[Any, Any],
async_: bool | None = None,
name: str | None = None,
) -> None:
if async_ is None:
async_ = isinstance(ctx, AbstractAsyncContextManager)
@self.on_context(name=name)
async def __make_context() -> AsyncGenerator[Any, Any]: # pyright: ignore[reportUnusedFunction]
if async_:
async with cast(AbstractAsyncContextManager[Any, Any], ctx) as c:
yield c
else:
with cast(AbstractContextManager[Any, Any], ctx) as c:
yield c
def get_context[T](self, name: str, type_: type[T] | None = None) -> T: # pyright: ignore[reportUnusedParameter]
return cast(T, self.loaded_context[name])