Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 99 additions & 29 deletions src/libkernelbot/background_submission_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,34 @@ def __init__(
self.min_workers = min_workers
self.max_workers = max_workers

def _prune_finished_workers_locked(self):
"""Drop finished worker tasks before using _workers as capacity."""
alive = []
for task in self._workers:
if not task.done():
alive.append(task)
continue

if task.cancelled():
logger.info("[Background Job] pruned cancelled worker %r", task.get_name())
continue

exc = task.exception()
if exc is not None:
logger.error(
"[Background Job] pruned failed worker %r",
task.get_name(),
exc_info=(type(exc), exc, exc.__traceback__),
)
else:
logger.info("[Background Job] pruned finished worker %r", task.get_name())
self._workers = alive

async def start(self):
logger.info("[Background Job] starting background submission manager")
async with self._state_lock:
self._accepting = True
self._prune_finished_workers_locked()
need = max(0, self.min_workers - len(self._workers))
for _ in range(need):
t = asyncio.create_task(self._worker_loop(), name="bg-worker")
Expand Down Expand Up @@ -140,12 +164,51 @@ async def enqueue(
await self._autoscale_up()
return job_id, sub_id

async def _maybe_scale_down_idle_worker(self) -> bool:
async with self._state_lock:
self._prune_finished_workers_locked()
me = asyncio.current_task()
if len(self._workers) <= self.min_workers or me not in self._workers:
return False

try:
self._workers.remove(me)
logger.info(
"[Background Job][worker %r] idle too long,"
"scale down; existing workers=%d",
me.get_name() if hasattr(me, "get_name") else id(me),
len(self._workers),
)
except ValueError:
pass
return True

async def _mark_job_failed_after_worker_crash(self, item: JobItem):
ts = dt.datetime.now(dt.timezone.utc)
try:
with self.backend.db as db:
db.upsert_submission_job_status(
item.sub_id,
status="failed",
last_heartbeat=ts,
error="worker crashed while processing submission",
)
except Exception:
logger.error(
"[Background Job][worker %r] failed to mark crashed "
"submission job `%s`",
id(asyncio.current_task()),
item.sub_id,
exc_info=True,
)

async def _worker_loop(self):
"""
A worker will keep listening to the queue, and process the job in the queue.
If the queue is empty, it will exit after idle_seconds.
Each worker only handles one submission job at a time
"""
crashed = False
try:
while True:
try:
Expand All @@ -158,26 +221,8 @@ async def _worker_loop(self):
item.sub_id,
)
except asyncio.TimeoutError:
async with self._state_lock:
me = asyncio.current_task()
if (
len(self._workers) > self.min_workers
and me in self._workers
):
try:
self._workers.remove(me)
logger.info(
"[Background Job][worker %r] idle too long,"
"scale down; existing workers=%d",
me.get_name()
if hasattr(me, "get_name")
else id(me),
len(self._workers),
)
except ValueError:
pass
return # scale down: exit

if await self._maybe_scale_down_idle_worker():
return # scale down: exit
continue

t = asyncio.create_task(
Expand All @@ -188,6 +233,14 @@ async def _worker_loop(self):
self._live_tasks.add(t)
try:
await t # wait submission job to finish
except Exception:
logger.error(
"[Background Job][worker %r] submission job `%s` crashed",
id(asyncio.current_task()),
item.sub_id,
exc_info=True,
)
await self._mark_job_failed_after_worker_crash(item)
finally:
logger.info(
"[Background Job][worker %r] finishes the submission job `%s`",
Expand All @@ -199,6 +252,20 @@ async def _worker_loop(self):
self.queue.task_done()
except asyncio.CancelledError:
return
except Exception:
crashed = True
logger.error(
"[Background Job][worker %r] worker loop crashed",
id(asyncio.current_task()),
exc_info=True,
)
finally:
me = asyncio.current_task()
async with self._state_lock:
if me in self._workers:
self._workers.remove(me)
if crashed:
await self._autoscale_up()

async def _task_done_async(self, tt: asyncio.Task, item: JobItem):
async with self._state_lock:
Expand All @@ -211,13 +278,9 @@ async def _run_one(self, item: JobItem):
now = dt.datetime.now(dt.timezone.utc)

logger.info("[Background Job] start processing submission %s", sub_id)
with self.backend.db as db:
db.upsert_submission_job_status(
sub_id, status="running", last_heartbeat=now
)

# heartbeat loop continuously update the last heartbeat time for the submission status
stop_heartbeat = asyncio.Event()
hb_task = None

async def heartbeat():
while not stop_heartbeat.is_set():
Expand All @@ -229,8 +292,13 @@ async def heartbeat():
except Exception:
pass

hb_task = asyncio.create_task(heartbeat(), name=f"hb-{sub_id}")
try:
with self.backend.db as db:
db.upsert_submission_job_status(
sub_id, status="running", last_heartbeat=now
)

hb_task = asyncio.create_task(heartbeat(), name=f"hb-{sub_id}")
reporter = BackgroundSubmissionManagerReporter()
await asyncio.wait_for(
self.backend.submit_full(
Expand Down Expand Up @@ -291,12 +359,14 @@ async def heartbeat():
)
finally:
stop_heartbeat.set()
hb_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await hb_task
if hb_task is not None:
hb_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await hb_task

async def _autoscale_up(self):
async with self._state_lock:
self._prune_finished_workers_locked()
running = len(self._live_tasks)
workers = len(self._workers)
qsize = self.queue.qsize()
Expand Down
71 changes: 71 additions & 0 deletions tests/test_background_submission_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,77 @@ async def fake_submit_full(req, mode, reporter, sub_id, skip_precheck=False):
await manager.stop()


@pytest.mark.asyncio
async def test_enqueue_prunes_dead_workers_before_autoscaling(mock_backend):
db_context = mock_backend.db
db_context.upsert_submission_job_status = mock.Mock(
side_effect=lambda *a, **k: a[0]
)
db_context.update_heartbeat_if_active = mock.Mock()

async def fake_submit_full(req, mode, reporter, sub_id, skip_precheck=False):
return None, None

mock_backend.submit_full = fake_submit_full

manager = BackgroundSubmissionManager(
mock_backend, min_workers=0, max_workers=1, idle_seconds=0.1
)
await manager.start()

async def dead_worker():
raise RuntimeError("dead worker")

dead_task = asyncio.create_task(dead_worker(), name="dead-bg-worker")
await asyncio.sleep(0)
async with manager._state_lock:
manager._workers.append(dead_task)

await manager.enqueue(get_req(1), SubmissionMode.TEST, sub_id=99)
await manager.queue.join()

assert (
mock.call(99, status="succeeded", last_heartbeat=mock.ANY)
in db_context.upsert_submission_job_status.call_args_list
)

await manager.stop()


@pytest.mark.asyncio
async def test_run_one_initial_status_failure_marks_failed(mock_backend):
db_context = mock_backend.db
statuses = []

def fake_upsert(sub_id, status=None, error=None, last_heartbeat=None):
statuses.append((status, error))
if status == "running":
raise RuntimeError("database unavailable")
return sub_id

db_context.upsert_submission_job_status = mock.Mock(side_effect=fake_upsert)
db_context.update_heartbeat_if_active = mock.Mock()
mock_backend.submit_full = mock.AsyncMock()

manager = BackgroundSubmissionManager(
mock_backend, min_workers=1, max_workers=1, idle_seconds=0.1
)
await manager.start()

await manager.enqueue(get_req(1), SubmissionMode.TEST, sub_id=123)
await manager.queue.join()

assert ("pending", None) in statuses
assert ("running", None) in statuses
assert any(status == "failed" and error == "database unavailable" for status, error in statuses)
mock_backend.submit_full.assert_not_called()

async with manager._state_lock:
assert len(manager._workers) == 1

await manager.stop()


@pytest.mark.asyncio
async def test_hacked_submission_sets_hacked_status(mock_backend):
db_context = mock_backend.db
Expand Down
Loading