11import asyncio
2+ import concurrent .futures
23from collections import deque , defaultdict
34from datetime import timedelta
45from itertools import chain
56import functools
67import logging
78import threading
89from time import time
9- from typing import Any , Callable , Hashable , Union
10+ from typing import Any , Callable , Coroutine , Hashable , Tuple , Union , overload
1011import weakref
1112
1213import toolz
@@ -730,6 +731,8 @@ class map_async(Stream):
730731 The arguments to pass to the function.
731732 parallelism:
732733 The maximum number of parallel Tasks for evaluating func, default value is 1
734+ stop_on_exception:
735+ If the mapped func raises an exception, should the stream stop or not. Default value is False.
733736 **kwargs:
734737 Keyword arguments to pass to func
735738
@@ -749,38 +752,72 @@ class map_async(Stream):
749752 6
750753 8
751754 """
752- def __init__ (self , upstream , func , * args , parallelism = 1 , ** kwargs ):
755+ def __init__ (self , upstream , func , * args , parallelism = 1 , stop_on_exception = False , ** kwargs ):
753756 self .func = func
754757 stream_name = kwargs .pop ('stream_name' , None )
755758 self .kwargs = kwargs
756759 self .args = args
760+ self .stop_on_exception = stop_on_exception
757761 self .work_queue = asyncio .Queue (maxsize = parallelism )
758762
759763 Stream .__init__ (self , upstream , stream_name = stream_name , ensure_io_loop = True )
760- self .work_task = self ._create_task (self .work_callback ())
764+ self .work_task = None
765+
766+ def _create_work_task (self ) -> Tuple [asyncio .Event , asyncio .Task [None ]]:
767+ stop_work = asyncio .Event ()
768+ work_task = self ._create_task (self .work_callback (stop_work ))
769+ return stop_work , work_task
770+
771+ def start (self ):
772+ if self .work_task :
773+ stop_work , _ = self .work_task
774+ stop_work .set ()
775+ self .work_task = self ._create_work_task ()
776+ super ().start ()
777+
778+ def stop (self ):
779+ stop_work , _ = self .work_task
780+ stop_work .set ()
781+ self .work_task = None
782+ super ().stop ()
761783
762784 def update (self , x , who = None , metadata = None ):
785+ if not self .work_task :
786+ self .work_task = self ._create_work_task ()
763787 return self ._create_task (self ._insert_job (x , metadata ))
764788
789+ @overload
790+ def _create_task (self , coro : asyncio .Future ) -> asyncio .Future :
791+ ...
792+
793+ @overload
794+ def _create_task (self , coro : concurrent .futures .Future ) -> concurrent .futures .Future :
795+ ...
796+
797+ @overload
798+ def _create_task (self , coro : Coroutine ) -> asyncio .Task :
799+ ...
800+
765801 def _create_task (self , coro ):
766802 if gen .is_future (coro ):
767803 return coro
768804 return self .loop .asyncio_loop .create_task (coro )
769805
770- async def work_callback (self ):
771- while True :
806+ async def work_callback (self , stop_work : asyncio .Event ):
807+ while not stop_work .is_set ():
808+ task , metadata = await self .work_queue .get ()
809+ self .work_queue .task_done ()
772810 try :
773- task , metadata = await self .work_queue .get ()
774- self .work_queue .task_done ()
775811 result = await task
776812 except Exception as e :
777813 logger .exception (e )
778- raise
814+ if self .stop_on_exception :
815+ self .stop ()
779816 else :
780817 results = self ._emit (result , metadata = metadata )
781818 if results :
782819 await asyncio .gather (* results )
783- self ._release_refs (metadata )
820+ self ._release_refs (metadata )
784821
785822 async def _wait_for_work_slot (self ):
786823 while self .work_queue .full ():
0 commit comments