diff --git a/app/components/base.py b/app/components/base.py index cf5a892..8e12e26 100644 --- a/app/components/base.py +++ b/app/components/base.py @@ -6,6 +6,7 @@ from app.types_ import AnyScope, AsyncCallable, Receive, Send class Component[S: AnyScope, R: Any](metaclass=ABCMeta): + @abstractmethod def __init__(self, *args: Any, **kwds: Any) -> None: pass @@ -22,6 +23,7 @@ class Component[S: AnyScope, R: Any](metaclass=ABCMeta): class RouteComponent[S: AnyScope, Recv_T: Any, Route_T: MutableMapping[str, Any], Route_R: Any](Component[S, Recv_T], metaclass=ABCMeta): routes: Route_T + @abstractmethod def __init__(self, *args: Any, **kwds: Any) -> None: super().__init__(*args, **kwds) @@ -30,6 +32,7 @@ class RouteComponent[S: AnyScope, Recv_T: Any, Route_T: MutableMapping[str, Any] """Route dispatcher""" raise NotImplementedError - def route_install(self, type_: str, route: str, target: AsyncCallable[..., Route_R]) -> None: + @abstractmethod + def route_install(self, route: str, target: AsyncCallable[..., Route_R], *, type_: str | None = None) -> None: """Install route target for specific type and route.""" - self.routes.setdefault(type_, {})[route] = target \ No newline at end of file + raise NotImplementedError \ No newline at end of file diff --git a/app/components/http.py b/app/components/http.py index d307fed..c134490 100644 --- a/app/components/http.py +++ b/app/components/http.py @@ -56,6 +56,13 @@ class HTTPComponent( if scope["path"] == k: # temporary impl. return await callee() + @override + def route_install(self, route: str, target: AsyncCallable[..., Response], *, type_: str | None = None) -> None: + """Install route target for specific type and route.""" + if type_ is None: + raise ValueError("Route type `type_` is unset.") + self.routes.setdefault(type_, {})[route] = target + def route[T: AsyncCallable[..., Response]]( self, route: str, @@ -67,13 +74,13 @@ class HTTPComponent( ) -> PassthroughDecorator[T]: def __wrap_route(fn: T) -> T: if get: - self.route_install("GET", route, fn) + self.route_install(route, fn, type_="GET") if post: - self.route_install("POST", route, fn) + self.route_install(route, fn, type_="POST") if put: - self.route_install("PUT", route, fn) + self.route_install(route, fn, type_="PUT") if delete: - self.route_install("DELETE", route, fn) + self.route_install(route, fn, type_="DELETE") return fn return __wrap_route diff --git a/app/components/lifespan.py b/app/components/lifespan.py index efb9171..f9d3ebc 100644 --- a/app/components/lifespan.py +++ b/app/components/lifespan.py @@ -1,25 +1,59 @@ -from typing import TypeGuard, override +from collections.abc import AsyncGenerator, Callable +from typing import Any, TypeGuard, override -from app.types_ import AnyScope, LifespanScope, Receive, ReceiveLifespan, Send +from app.subroutines.asyncutils import agzip +from app.types_ import ( + AnyScope, + AsyncCallable, + LifespanScope, + Receive, + ReceiveLifespan, + Send, +) from .base import Component as _Component class LifespanComponent(_Component[LifespanScope, ReceiveLifespan]): + startups: list[AsyncCallable[[], None]] + shutdowns: list[AsyncCallable[[], None]] + contexts: list[Callable[[], AsyncGenerator[Any, None]]] + + def __init__(self, *args: Any, **kwds: Any) -> None: + super().__init__(*args, **kwds) + self.startups = [] + self.shutdowns = [] + self.contexts = [] + @override async def condition(self, scope: AnyScope) -> TypeGuard[LifespanScope]: return scope["type"] == "lifespan" @override - async def handle(self, scope: LifespanScope, receive: Receive[ReceiveLifespan], send: Send) -> None: - while True: + async def handle( + self, scope: LifespanScope, receive: Receive[ReceiveLifespan], send: Send + ) -> None: + message = await receive() + async for _ in agzip(*[ctx() for ctx in self.contexts]): + if message["type"] == "lifespan.startup": + for fn in self.startups: + await fn() + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + for fn in self.shutdowns: + await fn() + await send({"type": "lifespan.shutdown.complete"}) + return message = await receive() - if message['type'] == 'lifespan.startup': - ... # Do some startup here! - print("Startup...") - await send({'type': 'lifespan.startup.complete'}) - elif message['type'] == 'lifespan.shutdown': - ... # Do some shutdown here! - print("Shutdown...") - await send({'type': 'lifespan.shutdown.complete'}) - return \ No newline at end of file + + def on_startup[Call_T: AsyncCallable[[], None]](self, fn: Call_T) -> Call_T: + self.startups.append(fn) + return fn + + def on_shutdown[Call_T: AsyncCallable[[], None]](self, fn: Call_T) -> Call_T: + self.shutdowns.append(fn) + return fn + + def on_context[Ctx_T: Callable[[], AsyncGenerator[Any, None]]](self, fn: Ctx_T) -> Ctx_T: + self.contexts.append(fn) + return fn diff --git a/app/subroutines/asyncutils.py b/app/subroutines/asyncutils.py new file mode 100644 index 0000000..d3ff2a3 --- /dev/null +++ b/app/subroutines/asyncutils.py @@ -0,0 +1,18 @@ +import asyncio +from collections.abc import AsyncGenerator +from typing import TypeVar + +T = TypeVar('T') + +async def agzip(*async_generators: AsyncGenerator[T, None]) -> AsyncGenerator[tuple[T, ...], None]: + """ + `zip()`-like function for `AsyncGenerator`s. + """ + iterators = [ag.__aiter__() for ag in async_generators] + + while True: + try: + results = await asyncio.gather(*[iterator.__anext__() for iterator in iterators]) + yield tuple(results) + except StopAsyncIteration: + break \ No newline at end of file diff --git a/test.py b/test.py index 7e0e126..b134202 100644 --- a/test.py +++ b/test.py @@ -1,15 +1,36 @@ +from collections.abc import AsyncGenerator +from typing import Any + from app import App from app.components.http import HTTPComponent - -# from app.components.lifespan import LifespanComponent +from app.components.lifespan import LifespanComponent from app.subroutines.http import HTMLResponse app = App() -# lifespan = app.use_component(LifespanComponent()) +lifespan = app.use_component(LifespanComponent()) http = app.use_component(HTTPComponent()) +@lifespan.on_context +async def my_context() -> AsyncGenerator[Any, None]: + try: + print("Start!") + yield + finally: + print("Stop!") + + +# @lifespan.on_startup +# async def start() -> None: +# print("Start!") + + +# @lifespan.on_shutdown +# async def stop() -> None: +# print("Stop!") + + @http.route("/teapot", get=True, post=True, put=True, delete=True) async def teapot() -> HTMLResponse: resp = """