Source code for visionsim.utils.progress

from __future__ import annotations

import multiprocessing
import queue
from functools import partial

from rich.progress import (
    BarColumn,
    Progress,
    ProgressColumn,
    TaskID,
    TaskProgressColumn,
    TextColumn,
    TimeRemainingColumn,
)

from visionsim.types import UpdateFn


class ElapsedProgress(Progress):
    """A ``rich.progress.Progress`` whose default columns show elapsed time once a task is finished."""

    @classmethod
    def get_default_columns(cls) -> tuple[ProgressColumn, ...]:
        """Return the default column layout used when no columns are given.

        Overrides `rich.progress.Progress`'s defaults so that the time column
        is created with ``elapsed_when_finished=True``.

        Returns:
            tuple[ProgressColumn, ...]: description, bar, percentage and time columns, in display order.
        """
        description_column = TextColumn("[progress.description]{task.description}")
        bar_column = BarColumn()
        percentage_column = TaskProgressColumn()
        time_column = TimeRemainingColumn(elapsed_when_finished=True)
        return description_column, bar_column, percentage_column, time_column
class PoolProgress(Progress):
    """Convenience wrapper around rich's ``Progress`` to enable progress bars when using multiple processes.

    All progressbar updates are carried out by the main process, and worker processes communicate
    their state via a callback obtained when a task gets added.

    Example:
        .. code-block:: python

            import multiprocessing

            def long_task(tick, min_len=50, max_len=200):
                import random, time

                length = random.randint(min_len, max_len)
                tick(total=length)

                for _ in range(length):
                    time.sleep(0.01)
                    tick(advance=1)

            if __name__ == "__main__":
                with multiprocessing.Pool(4) as pool, PoolProgress() as progress:
                    for i in range(25):
                        tick = progress.add_task(f"Task: {i}")
                        pool.apply_async(long_task, (tick, ))
                    progress.wait()
                    pool.close()
                    pool.join()
    """

    # Upper bound, in seconds, on how long ``wait`` blocks on the queue before
    # re-checking whether any tasks are still in flight.
    _POLL_INTERVAL = 0.1

    def __init__(self, *args, auto_visible=True, description="[green]Total progress:", **kwargs) -> None:
        """Initialize a ``PoolProgress`` instance.

        Note:
            All other \\*args and \\*\\*kwargs are passed as is to `rich.progress.Progress
            <https://rich.readthedocs.io/en/latest/reference/progress.html#rich.progress.Progress>`_.

        Args:
            auto_visible (bool, optional): if true, automatically hides tasks that have not started
                or finished tasks. Defaults to True.
            description (str, optional): text description for the overall progress.
                Defaults to "[green]Total progress:".
        """
        self.manager: multiprocessing.managers.SyncManager | None = None
        self.progress_queue: multiprocessing.Queue | None = None
        self.overall_taskid: TaskID | None = None
        self.inflight_tasks: set[TaskID] = set()
        self.completed_tasks: set[TaskID] = set()
        self.auto_visible = auto_visible
        self.description = description
        super().__init__(*args, **kwargs)

    @classmethod
    def get_default_columns(cls) -> tuple[ProgressColumn, ...]:
        """Overrides ``rich.progress.Progress``'s default columns to enable showing elapsed time when finished."""
        # Same layout as ``ElapsedProgress``; delegate so the two cannot drift apart.
        return ElapsedProgress.get_default_columns()

    @staticmethod
    def update_task(progress: multiprocessing.Queue[dict], task_id: TaskID, **kwargs) -> None:
        """Enqueue an update for ``task_id`` without touching the progress bar.

        Args:
            progress: queue on which the update is put.
            task_id: id of the task the update is for.
            **kwargs: fields of the update; they are later forwarded, together with
                ``task_id``, to ``Progress.update`` by :meth:`wait`.
        """
        progress.put(kwargs | dict(task_id=task_id))

    def _update_task(self, task_update: dict) -> None:
        """Actually perform the queued task update."""
        if self.auto_visible:
            # A task that reports in has started, so reveal it. Build a new dict
            # rather than mutating the one handed to us.
            task_update = task_update | dict(visible=True)
        self.update(**task_update)

    def task_percentage(self, task_id: TaskID) -> float:
        """Return the ``percentage`` attribute of the task with id ``task_id``."""
        with self._lock:
            return self._tasks[task_id].percentage

    def task_finished(self, task_id: TaskID) -> bool:
        """Return the ``finished`` attribute of the task with id ``task_id``."""
        with self._lock:
            return self._tasks[task_id].finished

    def add_task(self, *args, **kwargs) -> UpdateFn:  # type: ignore[override]
        """Same as `Progress.add_task` except it returns a callback to update the task instead of the task-id.

        The returned callback is roughly equivalent to `Progress.update` with its first
        argument (the task-id) already filled out, except calling it will not immediately
        update the task's status. The main process will perform the update asynchronously.

        Raises:
            RuntimeError: if called while the context manager is not entered.
        """
        if self.progress_queue is None:
            raise RuntimeError("Cannot add task if context manager has not been entered.")
        if self.auto_visible:
            # Stay hidden until the first update for this task is processed.
            kwargs["visible"] = False
        task_id = super().add_task(*args, **kwargs)
        update = partial(self.update_task, self.progress_queue, task_id)
        self.inflight_tasks.add(task_id)
        return update

    def __enter__(self):
        self.start()
        try:
            self.manager = multiprocessing.Manager().__enter__()
            self.overall_taskid = super().add_task(self.description)
            self.progress_queue = self.manager.Queue()
        except BaseException:
            # ``__exit__`` is not invoked when ``__enter__`` raises, so undo the
            # partial setup here before propagating the error.
            self.__exit__(None, None, None)
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Drop our references first so that later ``add_task``/``wait`` calls raise
        # the documented RuntimeError instead of talking to a dead manager.
        manager, self.manager, self.progress_queue = self.manager, None, None
        try:
            if manager is not None:
                manager.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Always stop the live display, even if shutting down the manager failed.
            suppress = super().__exit__(exc_type, exc_val, exc_tb)
        return suppress

    def _handle_update(self, task_update: dict) -> None:
        """Apply one queued update and retire its task if that update finished it."""
        task_id = task_update.get("task_id")
        self._update_task(task_update)

        if not self.task_finished(task_id):
            return
        if self.auto_visible:
            self.update(task_id, visible=False)
        if task_id in self.inflight_tasks:
            self.completed_tasks.add(task_id)
            self.inflight_tasks.remove(task_id)

    def _sync_overall(self) -> None:
        """Recompute the overall bar from the completed and in-flight task sets."""
        finished = len(self.completed_tasks)
        # Each in-flight task contributes its fractional completion (percentage / 100).
        fractional = sum(self.task_percentage(t) / 100 for t in self.inflight_tasks)
        self.update(
            self.overall_taskid,
            completed=finished + fractional,
            total=finished + len(self.inflight_tasks),
        )

    def wait(self) -> None:
        """Block and wait for tasks to finish.

        Note:
            This is what actually updates the progress bars, if not called before exiting the
            with-block no progress will be reported, and processes might be killed.

        Raises:
            RuntimeError: if called while the context manager is not entered.
        """
        if self.progress_queue is None or self.overall_taskid is None:
            raise RuntimeError("Cannot wait on tasks outside of context manager.")

        while self.inflight_tasks:
            try:
                # Block for the next update rather than spinning on ``empty()``.
                task_update = self.progress_queue.get(timeout=self._POLL_INTERVAL)
            except queue.Empty:
                # Nothing arrived, so nothing changed; go back to waiting.
                continue
            self._handle_update(task_update)

            # Drain whatever else is already queued so the overall bar is
            # recomputed once per batch instead of once per message.
            while True:
                try:
                    task_update = self.progress_queue.get_nowait()
                except queue.Empty:
                    break
                self._handle_update(task_update)

            self._sync_overall()

        # Every task is done: pin the overall bar to exactly 100%.
        self.update(
            self.overall_taskid,
            completed=len(self.completed_tasks),
            total=len(self.completed_tasks),
        )