From 7cb0a87d9cb6e79dc13885914088986e805cd9c2 Mon Sep 17 00:00:00 2001 From: worldmozara Date: Tue, 15 Apr 2025 16:36:23 +0800 Subject: [PATCH] upload --- app/components/http.py | 3 + app/components/lifespan.py | 105 ++++++++++++++++++++++++++++++++-- app/exceptions.py | 4 ++ app/subroutines/asyncutils.py | 18 ------ test.py | 14 ++--- 5 files changed, 111 insertions(+), 33 deletions(-) delete mode 100644 app/subroutines/asyncutils.py diff --git a/app/components/http.py b/app/components/http.py index c134490..9b05de9 100644 --- a/app/components/http.py +++ b/app/components/http.py @@ -71,6 +71,7 @@ class HTTPComponent( post: bool = False, put: bool = False, delete: bool = False, + head: bool = False, ) -> PassthroughDecorator[T]: def __wrap_route(fn: T) -> T: if get: @@ -81,6 +82,8 @@ class HTTPComponent( self.route_install(route, fn, type_="PUT") if delete: self.route_install(route, fn, type_="DELETE") + if head: + self.route_install(route, fn, type_="HEAD") return fn return __wrap_route diff --git a/app/components/lifespan.py b/app/components/lifespan.py index f9d3ebc..0b2775d 100644 --- a/app/components/lifespan.py +++ b/app/components/lifespan.py @@ -1,11 +1,14 @@ +import asyncio from collections.abc import AsyncGenerator, Callable -from typing import Any, TypeGuard, override +from contextlib import AbstractAsyncContextManager, AbstractContextManager +from typing import Any, TypeGuard, cast, overload, override -from app.subroutines.asyncutils import agzip +from app.exceptions import LifespanError from app.types_ import ( AnyScope, AsyncCallable, LifespanScope, + PassthroughDecorator, Receive, ReceiveLifespan, Send, @@ -14,16 +17,37 @@ from app.types_ import ( from .base import Component as _Component +async def resolve_context[T]( + *async_generators: AsyncGenerator[T, None], +) -> AsyncGenerator[tuple[T, ...], None]: + """ + Resolve `AsyncGenerator`s context. + """ + iterators = [ag.__aiter__() for ag in async_generators] + + while True: + try: + results = await asyncio.gather( + *[iterator.__anext__() for iterator in iterators] + ) + yield tuple(results) + except StopAsyncIteration: + break + yield () + + class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]): startups: list[AsyncCallable[[], None]] shutdowns: list[AsyncCallable[[], None]] - contexts: list[Callable[[], AsyncGenerator[Any, None]]] + contexts: list[tuple[str | None, Callable[[], AsyncGenerator[Any, None]]]] + loaded_context: dict[str, Any] def __init__(self, *args: Any, **kwds: Any) -> None: super().__init__(*args, **kwds) self.startups = [] self.shutdowns = [] self.contexts = [] + self.loaded_context = {} @override async def condition(self, scope: AnyScope) -> TypeGuard[LifespanScope]: @@ -34,10 +58,18 @@ class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]): self, scope: LifespanScope, receive: Receive[ReceiveLifespan], send: Send ) -> None: message = await receive() - async for _ in agzip(*[ctx() for ctx in self.contexts]): + async for ctxs in resolve_context(*[ctx[1]() for ctx in self.contexts]): if message["type"] == "lifespan.startup": for fn in self.startups: await fn() + for name, val in zip((ctx[0] for ctx in self.contexts), ctxs): + if name is None: + continue + if name in self.loaded_context: + raise LifespanError( + f"Name {name!r} is already used by context {self.loaded_context[name]!r}." + ) + self.loaded_context[name] = val await send({"type": "lifespan.startup.complete"}) elif message["type"] == "lifespan.shutdown": for fn in self.shutdowns: @@ -54,6 +86,67 @@ class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]): self.shutdowns.append(fn) return fn - def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](self, fn: Ctx_T) -> Ctx_T: - self.contexts.append(fn) + @overload + def on_context( + self, *, name: str | None = None + ) -> PassthroughDecorator[Callable[[], AsyncGenerator[Any, None]]]: ... + + @overload + def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]]( + self, fn: Ctx_T + ) -> Ctx_T: ... + + def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]]( + self, fn: Ctx_T | None = None, *, name: str | None = None + ) -> PassthroughDecorator[Ctx_T] | Ctx_T: + if fn is None: + + def __wrap_context(fn: Ctx_T) -> Ctx_T: + self.contexts.append((name, fn)) + return fn + + return __wrap_context + self.contexts.append((name, fn)) return fn + + @overload + def add_managed_context( + self, + ctx: AbstractContextManager[Any, Any], + async_: None = None, + name: str | None = None, + ) -> None: ... + + @overload + def add_managed_context( + self, + ctx: AbstractAsyncContextManager[Any, Any], + async_: None = None, + name: str | None = None, + ) -> None: ... + + @overload + def add_managed_context( + self, ctx: Any, async_: bool, name: str | None = None + ) -> None: ... + + def add_managed_context( + self, + ctx: AbstractContextManager[Any, Any] | AbstractAsyncContextManager[Any, Any], + async_: bool | None = None, + name: str | None = None, + ) -> None: + if async_ is None: + async_ = isinstance(ctx, AbstractAsyncContextManager) + + @self.on_context(name=name) + async def __make_context() -> AsyncGenerator[Any, Any]: # pyright: ignore[reportUnusedFunction] + if async_: + async with cast(AbstractAsyncContextManager[Any, Any], ctx) as c: + yield c + else: + with cast(AbstractContextManager[Any, Any], ctx) as c: + yield c + + def get_context[T](self, name: str, type_: type[T] | None = None) -> T: # pyright: ignore[reportUnusedParameter] + return cast(T, self.loaded_context[name]) diff --git a/app/exceptions.py b/app/exceptions.py index 39c676b..7c4b20b 100644 --- a/app/exceptions.py +++ b/app/exceptions.py @@ -2,6 +2,10 @@ class AppError(Exception): pass +class LifespanError(AppError): + pass + + class ConnectionClosed(AppError): pass diff --git a/app/subroutines/asyncutils.py b/app/subroutines/asyncutils.py deleted file mode 100644 index d3ff2a3..0000000 --- a/app/subroutines/asyncutils.py +++ /dev/null @@ -1,18 +0,0 @@ -import asyncio -from collections.abc import AsyncGenerator -from typing import TypeVar - -T = TypeVar('T') - -async def agzip(*async_generators: AsyncGenerator[T, None]) -> AsyncGenerator[tuple[T, ...], None]: - """ - `zip()`-like function for `AsyncGenerator`s. - """ - iterators = [ag.__aiter__() for ag in async_generators] - - while True: - try: - results = await asyncio.gather(*[iterator.__anext__() for iterator in iterators]) - yield tuple(results) - except StopAsyncIteration: - break \ No newline at end of file diff --git a/test.py b/test.py index b134202..180e6e5 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,5 @@ from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, TextIO from app import App from app.components.http import HTTPComponent @@ -21,14 +21,7 @@ async def my_context() -> AsyncGenerator[Any, None]: print("Stop!") -# @lifespan.on_startup -# async def start() -> None: -# print("Start!") - - -# @lifespan.on_shutdown -# async def stop() -> None: -# print("Stop!") +lifespan.add_managed_context(open("teapot.log", "w"), name="teapot_log") @http.route("/teapot", get=True, post=True, put=True, delete=True) @@ -45,4 +38,7 @@ async def teapot() -> HTMLResponse: """ + log = lifespan.get_context("teapot_log", TextIO) + _ = log.write("teapot\n") + log.flush() return HTMLResponse(status=418, content=resp)