This commit is contained in:
2025-04-15 16:36:23 +08:00
parent dcceb275ad
commit 7cb0a87d9c
5 changed files with 111 additions and 33 deletions

View File

@@ -71,6 +71,7 @@ class HTTPComponent(
post: bool = False, post: bool = False,
put: bool = False, put: bool = False,
delete: bool = False, delete: bool = False,
head: bool = False,
) -> PassthroughDecorator[T]: ) -> PassthroughDecorator[T]:
def __wrap_route(fn: T) -> T: def __wrap_route(fn: T) -> T:
if get: if get:
@@ -81,6 +82,8 @@ class HTTPComponent(
self.route_install(route, fn, type_="PUT") self.route_install(route, fn, type_="PUT")
if delete: if delete:
self.route_install(route, fn, type_="DELETE") self.route_install(route, fn, type_="DELETE")
if head:
self.route_install(route, fn, type_="HEAD")
return fn return fn
return __wrap_route return __wrap_route

View File

@@ -1,11 +1,14 @@
import asyncio
from collections.abc import AsyncGenerator, Callable from collections.abc import AsyncGenerator, Callable
from typing import Any, TypeGuard, override from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import Any, TypeGuard, cast, overload, override
from app.subroutines.asyncutils import agzip from app.exceptions import LifespanError
from app.types_ import ( from app.types_ import (
AnyScope, AnyScope,
AsyncCallable, AsyncCallable,
LifespanScope, LifespanScope,
PassthroughDecorator,
Receive, Receive,
ReceiveLifespan, ReceiveLifespan,
Send, Send,
@@ -14,16 +17,37 @@ from app.types_ import (
from .base import Component as _Component from .base import Component as _Component
async def resolve_context[T](
*async_generators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[T, ...], None]:
"""
Resolve `AsyncGenerator`s context.
"""
iterators = [ag.__aiter__() for ag in async_generators]
while True:
try:
results = await asyncio.gather(
*[iterator.__anext__() for iterator in iterators]
)
yield tuple(results)
except StopAsyncIteration:
break
yield ()
class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]): class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
startups: list[AsyncCallable[[], None]] startups: list[AsyncCallable[[], None]]
shutdowns: list[AsyncCallable[[], None]] shutdowns: list[AsyncCallable[[], None]]
contexts: list[Callable[[], AsyncGenerator[Any, None]]] contexts: list[tuple[str | None, Callable[[], AsyncGenerator[Any, None]]]]
loaded_context: dict[str, Any]
def __init__(self, *args: Any, **kwds: Any) -> None: def __init__(self, *args: Any, **kwds: Any) -> None:
super().__init__(*args, **kwds) super().__init__(*args, **kwds)
self.startups = [] self.startups = []
self.shutdowns = [] self.shutdowns = []
self.contexts = [] self.contexts = []
self.loaded_context = {}
@override @override
async def condition(self, scope: AnyScope) -> TypeGuard[LifespanScope]: async def condition(self, scope: AnyScope) -> TypeGuard[LifespanScope]:
@@ -34,10 +58,18 @@ class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
self, scope: LifespanScope, receive: Receive[ReceiveLifespan], send: Send self, scope: LifespanScope, receive: Receive[ReceiveLifespan], send: Send
) -> None: ) -> None:
message = await receive() message = await receive()
async for _ in agzip(*[ctx() for ctx in self.contexts]): async for ctxs in resolve_context(*[ctx[1]() for ctx in self.contexts]):
if message["type"] == "lifespan.startup": if message["type"] == "lifespan.startup":
for fn in self.startups: for fn in self.startups:
await fn() await fn()
for name, val in zip((ctx[0] for ctx in self.contexts), ctxs):
if name is None:
continue
if name in self.loaded_context:
raise LifespanError(
f"Name {name!r} is already used by context {self.loaded_context[name]!r}."
)
self.loaded_context[name] = val
await send({"type": "lifespan.startup.complete"}) await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown": elif message["type"] == "lifespan.shutdown":
for fn in self.shutdowns: for fn in self.shutdowns:
@@ -54,6 +86,67 @@ class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]):
self.shutdowns.append(fn) self.shutdowns.append(fn)
return fn return fn
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](self, fn: Ctx_T) -> Ctx_T: @overload
self.contexts.append(fn) def on_context(
self, *, name: str | None = None
) -> PassthroughDecorator[Callable[[], AsyncGenerator[Any, None]]]: ...
@overload
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](
self, fn: Ctx_T
) -> Ctx_T: ...
def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](
self, fn: Ctx_T | None = None, *, name: str | None = None
) -> PassthroughDecorator[Ctx_T] | Ctx_T:
if fn is None:
def __wrap_context(fn: Ctx_T) -> Ctx_T:
self.contexts.append((name, fn))
return fn
return __wrap_context
self.contexts.append((name, fn))
return fn return fn
@overload
def add_managed_context(
self,
ctx: AbstractContextManager[Any, Any],
async_: None = None,
name: str | None = None,
) -> None: ...
@overload
def add_managed_context(
self,
ctx: AbstractAsyncContextManager[Any, Any],
async_: None = None,
name: str | None = None,
) -> None: ...
@overload
def add_managed_context(
self, ctx: Any, async_: bool, name: str | None = None
) -> None: ...
def add_managed_context(
self,
ctx: AbstractContextManager[Any, Any] | AbstractAsyncContextManager[Any, Any],
async_: bool | None = None,
name: str | None = None,
) -> None:
if async_ is None:
async_ = isinstance(ctx, AbstractAsyncContextManager)
@self.on_context(name=name)
async def __make_context() -> AsyncGenerator[Any, Any]: # pyright: ignore[reportUnusedFunction]
if async_:
async with cast(AbstractAsyncContextManager[Any, Any], ctx) as c:
yield c
else:
with cast(AbstractContextManager[Any, Any], ctx) as c:
yield c
def get_context[T](self, name: str, type_: type[T] | None = None) -> T: # pyright: ignore[reportUnusedParameter]
return cast(T, self.loaded_context[name])

View File

@@ -2,6 +2,10 @@ class AppError(Exception):
pass pass
class LifespanError(AppError):
pass
class ConnectionClosed(AppError): class ConnectionClosed(AppError):
pass pass

View File

@@ -1,18 +0,0 @@
import asyncio
from collections.abc import AsyncGenerator
from typing import TypeVar
T = TypeVar('T')
async def agzip(*async_generators: AsyncGenerator[T, None]) -> AsyncGenerator[tuple[T, ...], None]:
"""
`zip()`-like function for `AsyncGenerator`s.
"""
iterators = [ag.__aiter__() for ag in async_generators]
while True:
try:
results = await asyncio.gather(*[iterator.__anext__() for iterator in iterators])
yield tuple(results)
except StopAsyncIteration:
break

14
test.py
View File

@@ -1,5 +1,5 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any, TextIO
from app import App from app import App
from app.components.http import HTTPComponent from app.components.http import HTTPComponent
@@ -21,14 +21,7 @@ async def my_context() -> AsyncGenerator[Any, None]:
print("Stop!") print("Stop!")
# @lifespan.on_startup lifespan.add_managed_context(open("teapot.log", "w"), name="teapot_log")
# async def start() -> None:
# print("Start!")
# @lifespan.on_shutdown
# async def stop() -> None:
# print("Stop!")
@http.route("/teapot", get=True, post=True, put=True, delete=True) @http.route("/teapot", get=True, post=True, put=True, delete=True)
@@ -45,4 +38,7 @@ async def teapot() -> HTMLResponse:
</body> </body>
</html> </html>
""" """
log = lifespan.get_context("teapot_log", TextIO)
_ = log.write("teapot\n")
log.flush()
return HTMLResponse(status=418, content=resp) return HTMLResponse(status=418, content=resp)