diff --git a/src/libkernelbot/background_submission_manager.py b/src/libkernelbot/background_submission_manager.py index fb98d779..c3d3d8b5 100644 --- a/src/libkernelbot/background_submission_manager.py +++ b/src/libkernelbot/background_submission_manager.py @@ -86,10 +86,34 @@ def __init__( self.min_workers = min_workers self.max_workers = max_workers + def _prune_finished_workers_locked(self): + """Drop finished worker tasks before using _workers as capacity.""" + alive = [] + for task in self._workers: + if not task.done(): + alive.append(task) + continue + + if task.cancelled(): + logger.info("[Background Job] pruned cancelled worker %r", task.get_name()) + continue + + exc = task.exception() + if exc is not None: + logger.error( + "[Background Job] pruned failed worker %r", + task.get_name(), + exc_info=(type(exc), exc, exc.__traceback__), + ) + else: + logger.info("[Background Job] pruned finished worker %r", task.get_name()) + self._workers = alive + async def start(self): logger.info("[Background Job] starting background submission manager") async with self._state_lock: self._accepting = True + self._prune_finished_workers_locked() need = max(0, self.min_workers - len(self._workers)) for _ in range(need): t = asyncio.create_task(self._worker_loop(), name="bg-worker") @@ -140,12 +164,51 @@ async def enqueue( await self._autoscale_up() return job_id, sub_id + async def _maybe_scale_down_idle_worker(self) -> bool: + async with self._state_lock: + self._prune_finished_workers_locked() + me = asyncio.current_task() + if len(self._workers) <= self.min_workers or me not in self._workers: + return False + + try: + self._workers.remove(me) + logger.info( + "[Background Job][worker %r] idle too long," + "scale down; existing workers=%d", + me.get_name() if hasattr(me, "get_name") else id(me), + len(self._workers), + ) + except ValueError: + pass + return True + + async def _mark_job_failed_after_worker_crash(self, item: JobItem): + ts = dt.datetime.now(dt.timezone.utc) + try: + with self.backend.db as db: + db.upsert_submission_job_status( + item.sub_id, + status="failed", + last_heartbeat=ts, + error="worker crashed while processing submission", + ) + except Exception: + logger.error( + "[Background Job][worker %r] failed to mark crashed " + "submission job `%s`", + id(asyncio.current_task()), + item.sub_id, + exc_info=True, + ) + async def _worker_loop(self): """ A worker will keep listening to the queue, and process the job in the queue. If the queue is empty, it will exit after idle_seconds. Each worker only handles one submission job at a time """ + crashed = False try: while True: try: @@ -158,26 +221,8 @@ async def _worker_loop(self): item.sub_id, ) except asyncio.TimeoutError: - async with self._state_lock: - me = asyncio.current_task() - if ( - len(self._workers) > self.min_workers - and me in self._workers - ): - try: - self._workers.remove(me) - logger.info( - "[Background Job][worker %r] idle too long," - "scale down; existing workers=%d", - me.get_name() - if hasattr(me, "get_name") - else id(me), - len(self._workers), - ) - except ValueError: - pass - return # scale down: exit - + if await self._maybe_scale_down_idle_worker(): + return # scale down: exit continue t = asyncio.create_task( @@ -188,6 +233,14 @@ async def _worker_loop(self): self._live_tasks.add(t) try: await t # wait submission job to finish + except Exception: + logger.error( + "[Background Job][worker %r] submission job `%s` crashed", + id(asyncio.current_task()), + item.sub_id, + exc_info=True, + ) + await self._mark_job_failed_after_worker_crash(item) finally: logger.info( "[Background Job][worker %r] finishes the submission job `%s`", @@ -199,6 +252,20 @@ async def _worker_loop(self): self.queue.task_done() except asyncio.CancelledError: return + except Exception: + crashed = True + logger.error( + "[Background Job][worker %r] worker loop crashed", + id(asyncio.current_task()), + exc_info=True, + ) + finally: + me = asyncio.current_task() + async with self._state_lock: + if me in self._workers: + self._workers.remove(me) + if crashed: + await self._autoscale_up() async def _task_done_async(self, tt: asyncio.Task, item: JobItem): async with self._state_lock: @@ -211,13 +278,9 @@ async def _run_one(self, item: JobItem): now = dt.datetime.now(dt.timezone.utc) logger.info("[Background Job] start processing submission %s", sub_id) - with self.backend.db as db: - db.upsert_submission_job_status( - sub_id, status="running", last_heartbeat=now - ) - # heartbeat loop continuously update the last heartbeat time for the submission status stop_heartbeat = asyncio.Event() + hb_task = None async def heartbeat(): while not stop_heartbeat.is_set(): @@ -229,8 +292,13 @@ async def heartbeat(): except Exception: pass - hb_task = asyncio.create_task(heartbeat(), name=f"hb-{sub_id}") try: + with self.backend.db as db: + db.upsert_submission_job_status( + sub_id, status="running", last_heartbeat=now + ) + + hb_task = asyncio.create_task(heartbeat(), name=f"hb-{sub_id}") reporter = BackgroundSubmissionManagerReporter() await asyncio.wait_for( self.backend.submit_full( @@ -291,12 +359,14 @@ async def heartbeat(): ) finally: stop_heartbeat.set() - hb_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await hb_task + if hb_task is not None: + hb_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await hb_task async def _autoscale_up(self): async with self._state_lock: + self._prune_finished_workers_locked() running = len(self._live_tasks) workers = len(self._workers) qsize = self.queue.qsize() diff --git a/tests/test_background_submission_manager.py b/tests/test_background_submission_manager.py index cfc28b8f..b072a17d 100644 --- a/tests/test_background_submission_manager.py +++ b/tests/test_background_submission_manager.py @@ -153,6 +153,77 @@ async def fake_submit_full(req, mode, reporter, sub_id, skip_precheck=False): await manager.stop() +@pytest.mark.asyncio +async def test_enqueue_prunes_dead_workers_before_autoscaling(mock_backend): + db_context = mock_backend.db + db_context.upsert_submission_job_status = mock.Mock( + side_effect=lambda *a, **k: a[0] + ) + db_context.update_heartbeat_if_active = mock.Mock() + + async def fake_submit_full(req, mode, reporter, sub_id, skip_precheck=False): + return None, None + + mock_backend.submit_full = fake_submit_full + + manager = BackgroundSubmissionManager( + mock_backend, min_workers=0, max_workers=1, idle_seconds=0.1 + ) + await manager.start() + + async def dead_worker(): + raise RuntimeError("dead worker") + + dead_task = asyncio.create_task(dead_worker(), name="dead-bg-worker") + await asyncio.sleep(0) + async with manager._state_lock: + manager._workers.append(dead_task) + + await manager.enqueue(get_req(1), SubmissionMode.TEST, sub_id=99) + await manager.queue.join() + + assert ( + mock.call(99, status="succeeded", last_heartbeat=mock.ANY) + in db_context.upsert_submission_job_status.call_args_list + ) + + await manager.stop() + + +@pytest.mark.asyncio +async def test_run_one_initial_status_failure_marks_failed(mock_backend): + db_context = mock_backend.db + statuses = [] + + def fake_upsert(sub_id, status=None, error=None, last_heartbeat=None): + statuses.append((status, error)) + if status == "running": + raise RuntimeError("database unavailable") + return sub_id + + db_context.upsert_submission_job_status = mock.Mock(side_effect=fake_upsert) + db_context.update_heartbeat_if_active = mock.Mock() + mock_backend.submit_full = mock.AsyncMock() + + manager = BackgroundSubmissionManager( + mock_backend, min_workers=1, max_workers=1, idle_seconds=0.1 + ) + await manager.start() + + await manager.enqueue(get_req(1), SubmissionMode.TEST, sub_id=123) + await manager.queue.join() + + assert ("pending", None) in statuses + assert ("running", None) in statuses + assert any(status == "failed" and error == "database unavailable" for status, error in statuses) + mock_backend.submit_full.assert_not_called() + + async with manager._state_lock: + assert len(manager._workers) == 1 + + await manager.stop() + + @pytest.mark.asyncio async def test_hacked_submission_sets_hacked_status(mock_backend): db_context = mock_backend.db