diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 085a43c..f6c517b 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -13,7 +13,7 @@ jobs: build: strategy: matrix: - python-version: ["3.10.15", "3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13"] os: - "ubuntu-latest" - "windows-latest" diff --git a/README.md b/README.md index 5e636d9..4ba2827 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,13 @@ A namespace package for [ezmsg](https://github.com/iscoe/ezmsg) to visualize running graphs and data. -The data visualization is highly fragile. Expect bugs. +Key features: + +* **Graph visualization** - Visualize ezmsg graph topologies +* **Data visualization** - Real-time data plotting and monitoring +* **Debug tools** - Tools for debugging ezmsg pipelines + +> The data visualization is highly fragile. Expect bugs. ## Installation @@ -16,6 +22,18 @@ On Mac, you should use brew: * `export CFLAGS="-I $(brew --prefix graphviz)/include"` * `export LDFLAGS="-L $(brew --prefix graphviz)/lib"` +On Windows, follow the instructions [here](https://pygraphviz.github.io/documentation/stable/install.html#windows). Short version: + +* Download and install [Visual C/C++](https://visualstudio.microsoft.com/visual-cpp-build-tools/) +* Download and install [graphviz](https://gitlab.com/graphviz/graphviz/-/releases). Get the most recent Windows x64 CMake releases. +* Install pygraphviz in your environment: +``` +python -m pip install --config-settings="--global-option=build_ext" + --config-settings="--global-option=-IC:\Program Files\Graphviz\include" + --config-settings="--global-option=-LC:\Program Files\Graphviz\lib" + pygraphviz +``` + ### Release Install the latest release from pypi with: `pip install ezmsg-tools` (or `uv add ...` or `poetry add ...`). @@ -24,26 +42,36 @@ More than likely, you will want to include at least one of the extras when insta `pip install "ezmsg-tools[all]"` -### Development Version +Or install the latest development version: -If you intend to edit `ezmsg-tools` then please refer to the [Developers](#developers) section below. +`pip install "ezmsg-tools[all] @ git+https://github.com/ezmsg-org/ezmsg-tools.git@dev"` -You can add the development version of ezmsg-tools directly from GitHub: +## Getting Started -* Using `pip`: `pip install git+https://github.com/ezmsg-org/ezmsg-tools.git@dev` -* Using `poetry`: `poetry add "git+https://github.com/ezmsg-org/ezmsg-tools.git@dev"` -* Using `uv`: `uv add git+https://github.com/ezmsg-org/ezmsg-tools --branch dev` +This package includes some entrypoints with useful tools. -You probably want to include the extras when installing the development version: +### ezmsg-signal-monitor -* `pip install "ezmsg-tools[all] @ git+https://github.com/ezmsg-org/ezmsg-tools.git@dev"` +The pipeline must be running on a graph service exposed on the network. For example, first, run the GraphService on an open port: -## Getting Started +`ezmsg --address 127.0.0.1:25978 start` -This package includes some entrypoints with useful tools. +Then run your usual pipeline but make sure it attaches to the graph address by passing `graph_address=("127.0.0.1", 25978)` as a kwarg to `ez.run`. + +While the pipeline is running, you can run the signal-monitor tool with (`uv run`) `ezmsg-signal-monitor --graph-addr 127.0.0.1:25978`. + +This launches a window with graph visualized on the left. Click on a node's output box to get a live visualization on the right side of the screen plotting the data as it leaves that node. Use `a` to toggle auto-scaling. With auto-scaling off, use `-`, and `=` to zoom out and in, respectively. See the [phosphor docs](https://www.ezmsg.org/phosphor/) for the full list of keyboard shortcuts. + +> Currently only 2-D outputs are supported! + +Don't forget to shutdown your graph service when you are done, e.g.: `ezmsg --address 127.0.0.1:25978 shutdown` ### ezmsg-performance-monitor +**DEPRECATED** + +> ezmsg will soon includes a built-in performance monitor that can be used instead of this tool. + This tool operates on logfiles created by ezmsg. Logfiles will automatically be created when running a pipeline containing nodes decorated with `ezmsg.sigproc.util.profile.profile_subpub`, and if the `EZMSG_LOGLEVEL` environment variable is set to DEBUG. The logfiles will be created in `~/.ezmsg/profile/ezprofiler.log` by default but this can be changed with the `EZMSG_PROFILE` environment variable. @@ -52,25 +80,11 @@ You can decorate other nodes with `ezmsg.sigproc.util.profile.profile_subpub` to During a run with profiling enabled, the logfiles will be created in the specified location. You may wish to additionally create a graph file: (`uv run`) `EZMSG_LOGLEVEL=WARN ezmsg mermaid > ~/.ezmsg/profile/ezprofiler.mermaid` -During or after a pipeline run with profiling enabled, you can run (`uv run `) `performance-monitor` to visualize the performance of the nodes in the pipeline. +During or after a pipeline run with profiling enabled, you can run (`uv run `) `ezmsg-performance-monitor` to visualize the performance of the nodes in the pipeline. > Unlike `signal-monitor`, this tool does not require the pipeline to attach to an existing graph service because it relies exclusively on the logfile. -### ezmsg-signal-monitor - -The pipeline must be running on a graph service exposed on the network. For example, first, run the GraphService on an open port: - -`ezmsg --address 127.0.0.1:25978 start` - -Then run your usual pipeline but make sure it attaches to the graph address by passing `graph_address=("127.0.0.1", 25978)` as a kwarg to `ez.run`. - -While the pipeline is running, you can run the signal-monitor tool with (`uv run`) `signal-monitor --graph-addr 127.0.0.1:25978`. - -This launches a window with graph visualized on the left. Click on a node's output box to get a live visualization on the right side of the screen plotting the data as it leaves that node. Use `a` to toggle auto-scaling. With auto-scaling off, use `-`, and `=` to zoom out and in, respectively. - -> Currently only 2-D outputs are supported! - -Don't forget to shutdown your graph service when you are done, e.g.: `ezmsg --address 127.0.0.1:25978 shutdown` +> This performance monitor is soon to be deprecated in favor of monitoring tools built-in to ezmsg. ## Developers @@ -80,9 +94,9 @@ We use [`uv`](https://docs.astral.sh/uv/getting-started/installation/) for devel 2. Fork ezmsg-tools and clone your fork to your local computer. 3. Open a terminal and `cd` to the cloned folder. 4. Make sure `pygraphviz` [pre-requisites](#pre-requisites) are installed. - * On mac: `export CFLAGS="-I $(brew --prefix graphviz)/include"` and `export LDFLAGS="-L $(brew --prefix graphviz)/lib"` -5. `uv sync --all-extras --python 3.10` to create a .venv and install ezmsg-tools including dev and test dependencies. -6. After editing code and making commits, Run the test suite before making a PR: `uv run pytest` +5. `uv sync --all-extras` to create a .venv and install ezmsg-tools including dev and test dependencies. +6. (Optional) Install pre-commit hooks: `uv run pre-commit install` +7. After editing code and making commits, Run the test suite before making a PR: `uv run pytest` ## Troubleshooting diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 0000000..8fff9d6 --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,16 @@ +```{include} ../../README.md +``` + +## Documentation + +```{toctree} +:maxdepth: 2 +:caption: Contents: + +api/index +``` + +## Indices and tables + +- {ref}`genindex` +- {ref}`modindex` diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index a34d4ab..0000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,69 +0,0 @@ -ezmsg.tools -=========== - -Tools to visualize running graphs and data in ezmsg. - -Overview --------- - -``ezmsg-tools`` tools to visualize running graphs and data in ezmsg. - -Key features: - -* **Graph visualization** - Visualize ezmsg graph topologies -* **Data visualization** - Real-time data plotting and monitoring -* **Debug tools** - Tools for debugging ezmsg pipelines - - -.. note:: - The data visualization is highly fragile. Expect bugs. - - -Installation ------------- - -Install from PyPI: - -.. code-block:: bash - - pip install ezmsg-tools - -Or install the latest development version: - -.. code-block:: bash - - pip install git+https://github.com/ezmsg-org/ezmsg-tools@main - -Dependencies -^^^^^^^^^^^^ - -Core dependencies: - -* ``ezmsg`` -* ``various visualization libraries`` - -Quick Start ------------ - -For general ezmsg tutorials and guides, visit `ezmsg.org `_. - -For package-specific documentation: - -* **API Reference** - See :doc:`api/index` for complete API documentation -* **README** - See the `GitHub repository `_ for usage examples - -Documentation -------------- - -.. toctree:: - :maxdepth: 2 - :caption: Contents: - - api/index - - -Indices and tables ------------------- - -* :ref:`genindex` -* :ref:`modindex` diff --git a/pyproject.toml b/pyproject.toml index 4f9d347..e8991ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,18 +6,20 @@ authors = [ ] license = "MIT" readme = "README.md" -requires-python = ">=3.10.15" +requires-python = ">=3.11" dynamic = ["version"] dependencies = [ - "ezmsg>=3.6.1", + "ezmsg>=3.6.2", "numpy>=1.26.0", + "typer>=0.24.1", ] [dependency-groups] dev = [ "pre-commit>=4.0.0", "scipy>=1.14.1", - "ezmsg-sigproc>=2.0.0", + "ezmsg-sigproc>=2.18.0", + "ezmsg-simbiophys>=1.4.1", {include-group = "lint"}, {include-group = "test"}, ] @@ -25,8 +27,8 @@ lint = [ "ruff>=0.12.9", ] test = [ - "ezmsg-sigproc>=1.6.0", "pytest>=8.3.3", + "ezmsg-simbiophys>=1.4.1", ] docs = [ "sphinx>=7.0", @@ -45,17 +47,20 @@ perfmon = [ "typer>=0.15.1", "pygtail>=0.14.0", "dash-bootstrap-components>=1.6.0", - "ezmsg-baseproc", + "ezmsg-baseproc>=1.1.0", ] sigmon = [ - "pygame>=2.6.1", + "PySide6>=6.7", "pygraphviz>=1.14", "typer>=0.15.1", + "phosphor>=0.2", + "pandas", + "ezmsg-qt", ] [project.scripts] -ezmsg-performance-monitor = "ezmsg.tools.perfmon.main:main" -ezmsg-signal-monitor = "ezmsg.tools.sigmon.main:main" +ezmsg-performance-monitor = "ezmsg.tools.perfmon.cli:main" +ezmsg-signal-monitor = "ezmsg.tools.sigmon.cli:main" [build-system] requires = ["hatchling", "hatch-vcs"] @@ -89,4 +94,5 @@ known-third-party = ["ezmsg"] [tool.uv.sources] # Uncomment to use development version of ezmsg from git -ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" } +#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "dev" } +ezmsg-qt = { git = "https://github.com/ezmsg-org/ezmsg-qt.git", branch = "dynamic_subscriber" } diff --git a/scripts_nbs/generator/eeg_generator_graph.py b/scripts_nbs/generator/eeg_generator_graph.py new file mode 100644 index 0000000..e341359 --- /dev/null +++ b/scripts_nbs/generator/eeg_generator_graph.py @@ -0,0 +1,74 @@ +"""Test graph for sigmon development. + +Generates synthetic EEG (time-domain) and routes a copy through +Window + Spectrum to produce frequency-domain data. Both branches +terminate at a no-op Sink so the topics exist in the graph for +sigmon to subscribe to. + +Usage: + uv run python scripts/eeg_generator_graph.py +""" + +import ezmsg.core as ez +from ezmsg.sigproc.spectrum import Spectrum, SpectrumSettings +from ezmsg.sigproc.window import Window, WindowSettings +from ezmsg.simbiophys import EEGSynth, EEGSynthSettings +from ezmsg.util.messages.axisarray import AxisArray + + +class Sink(ez.Unit): + """Consumes messages and does nothing.""" + + INPUT_SIGNAL = ez.InputStream(AxisArray) + + @ez.subscriber(INPUT_SIGNAL) + async def on_message(self, msg: AxisArray) -> None: + pass + + +def main() -> None: + eeg = EEGSynth( + EEGSynthSettings( + fs=2000.0, + n_time=100, + n_ch=16, + alpha_freq=10.5, + ) + ) + + win = Window( + WindowSettings( + axis="time", + window_dur=0.5, + window_shift=0.2, + ) + ) + + spec = Spectrum( + SpectrumSettings( + axis="time", + ) + ) + + time_sink = Sink() + freq_sink = Sink() + + ez.run( + components={ + "EEG": eeg, + "WIN": win, + "SPEC": spec, + "TIME_SINK": time_sink, + "FREQ_SINK": freq_sink, + }, + connections=( + (eeg.OUTPUT_SIGNAL, time_sink.INPUT_SIGNAL), + (eeg.OUTPUT_SIGNAL, win.INPUT_SIGNAL), + (win.OUTPUT_SIGNAL, spec.INPUT_SIGNAL), + (spec.OUTPUT_SIGNAL, freq_sink.INPUT_SIGNAL), + ), + ) + + +if __name__ == "__main__": + main() diff --git a/scripts_nbs/profiler/ecog_preproc.py b/scripts_nbs/profiler/ecog_preproc.py index a9bddbe..0613c58 100644 --- a/scripts_nbs/profiler/ecog_preproc.py +++ b/scripts_nbs/profiler/ecog_preproc.py @@ -7,9 +7,9 @@ from ezmsg.sigproc.downsample import Downsample from ezmsg.sigproc.scaler import AdaptiveStandardScaler from ezmsg.sigproc.slicer import Slicer -from ezmsg.sigproc.synth import EEGSynth from ezmsg.sigproc.wavelets import CWT, MinPhaseMode from ezmsg.sigproc.window import Anchor +from ezmsg.simbiophys.eeg import EEGSynth from ezmsg.util.terminate import TerminateOnTotal diff --git a/scripts_nbs/profiler/mp_demo.py b/scripts_nbs/profiler/mp_demo.py deleted file mode 100644 index 6d2c06a..0000000 --- a/scripts_nbs/profiler/mp_demo.py +++ /dev/null @@ -1,44 +0,0 @@ -import sys - -import ezmsg.core as ez -from ezmsg.sigproc.synth import Counter, CounterSettings -from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings - -sys.path.append("..") -from nodes.dummy import Dummy, DummySettings - - -def main( - do_multi: bool = False, - sleep_time: float = 0.005, - fs: float = 10.0, - run_duration: float = 35.0, -): - n_msgs = int(run_duration * fs) - comps = { - "SOURCE": Counter(CounterSettings(n_time=1, fs=fs, dispatch_rate="realtime")), - "DUMMY1": Dummy(DummySettings(mean=sleep_time, stddev=0.0)), - "DUMMY2": Dummy(DummySettings(mean=sleep_time, stddev=0.0)), - "SINK": TerminateOnTotal(TerminateOnTotalSettings(total=n_msgs * 2)), - } - conns = ( - (comps["SOURCE"].OUTPUT_SIGNAL, comps["DUMMY1"].INPUT_SIGNAL), - (comps["SOURCE"].OUTPUT_SIGNAL, comps["DUMMY2"].INPUT_SIGNAL), - (comps["DUMMY1"].OUTPUT_SIGNAL, comps["SINK"].INPUT_MESSAGE), - (comps["DUMMY2"].OUTPUT_SIGNAL, comps["SINK"].INPUT_MESSAGE), - ) - ez.run( - components=comps, - connections=conns, - # graph_address=("127.0.0.1", 25978), - process_components=(comps["DUMMY2"],) if do_multi else (), - ) - - -if __name__ == "__main__": - try: - import typer - - typer.run(main) - except ModuleNotFoundError: - main() diff --git a/src/ezmsg/tools/dag.py b/src/ezmsg/tools/dag.py index 2a7f636..c196b51 100644 --- a/src/ezmsg/tools/dag.py +++ b/src/ezmsg/tools/dag.py @@ -1,4 +1,5 @@ import asyncio +import logging import typing from collections import defaultdict from typing import TYPE_CHECKING @@ -10,31 +11,49 @@ if TYPE_CHECKING: import pygraphviz +logger = logging.getLogger(__name__) -def get_graph(graph_address: typing.Tuple[str, int]) -> "pygraphviz.AGraph": + +def get_graph(graph_address: typing.Tuple[str, int], timeout: float = 5.0) -> "pygraphviz.AGraph": import pygraphviz as pgv # Create a graphviz object with our graph components as nodes and our connections as edges. G = pgv.AGraph(name="ezmsg-graphviz", strict=False, directed=True) G.graph_attr["label"] = "ezmsg-graphviz" G.graph_attr["rankdir"] = "TB" - # G.graph_attr["outputorder"] = "edgesfirst" - # G.graph_attr["ratio"] = "1.0" - # G.node_attr["shape"] = "circle" - # G.node_attr["fixedsize"] = "true" G.node_attr["fontsize"] = "8" G.node_attr["fontcolor"] = "#000000" G.node_attr["style"] = "filled" G.edge_attr["color"] = "#0000FF" G.edge_attr["style"] = "setlinewidth(2)" - # Get the dag from the GraphService + # Get the dag from the GraphService with timeout loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - dag = loop.run_until_complete(ez.graphserver.GraphService(address=graph_address).dag()) + + async def dag_with_timeout(): + return await asyncio.wait_for(ez.graphserver.GraphService(address=graph_address).dag(), timeout=timeout) + + try: + dag = loop.run_until_complete(dag_with_timeout()) + except asyncio.TimeoutError: + logger.warning(f"GraphService.dag() timed out after {timeout}s - returning empty graph") + return G + except (ConnectionRefusedError, OSError) as e: + logger.warning(f"GraphService.dag() connection failed: {e} - returning empty graph") + return G + except Exception as e: + logger.warning(f"GraphService.dag() failed: {type(e).__name__}: {e} - returning empty graph") + return G + finally: + loop.close() # Retrieve a description of the graph graph_connections = dag.graph.copy() + + # Handle empty graph - return early with minimal valid AGraph + if not graph_connections: + return G # graph_connections is a dict with format # { # 'apath/unit/port': {'some/other_unit/port', 'yet/another/unit/port'}, @@ -42,14 +61,14 @@ def get_graph(graph_address: typing.Tuple[str, int]) -> "pygraphviz.AGraph": # where 'port' might be a pub (out) stream or a sub (input) stream. b_refresh_dag = False + _monitor_topics = {"VISBUFF/INPUT_SIGNAL", "SIGMON/INPUT"} for k, v in graph_connections.items(): - if "VISBUFF/INPUT_SIGNAL" in v: - b_refresh_dag = True - loop.run_until_complete( - ez.graphserver.GraphService(address=graph_address).disconnect(k, "VISBUFF/INPUT_SIGNAL") - ) + for sub in v: + if any(mt in sub for mt in _monitor_topics): + b_refresh_dag = True + asyncio.run(ez.graphserver.GraphService(address=graph_address).disconnect(k, sub)) if b_refresh_dag: - dag = loop.run_until_complete(ez.graphserver.GraphService(address=graph_address).dag()) + dag = asyncio.run(ez.graphserver.GraphService(address=graph_address).dag()) graph_connections = dag.graph.copy() # Generate UUID node names diff --git a/src/ezmsg/tools/perfmon/main.py b/src/ezmsg/tools/perfmon/cli.py similarity index 99% rename from src/ezmsg/tools/perfmon/main.py rename to src/ezmsg/tools/perfmon/cli.py index 89776a8..e79e368 100644 --- a/src/ezmsg/tools/perfmon/main.py +++ b/src/ezmsg/tools/perfmon/cli.py @@ -279,5 +279,9 @@ def update_hist(data): return fig, f"Sum: {proc_sum:.2f} ms" +def main() -> None: + app.run(debug=True) + + if __name__ == "__main__": app.run(debug=True) diff --git a/src/ezmsg/tools/proc.py b/src/ezmsg/tools/proc.py deleted file mode 100644 index e982109..0000000 --- a/src/ezmsg/tools/proc.py +++ /dev/null @@ -1,87 +0,0 @@ -import asyncio -import multiprocessing -import multiprocessing.connection -import typing - -import ezmsg.core as ez - -from .shmem.shmem import ShMemCircBuff, ShMemCircBuffSettings - -BUF_DUR = 3.0 - - -class EzMonitorProcess(multiprocessing.Process): - def __init__( - self, - settings: ShMemCircBuffSettings, - topic: str, - address: typing.Optional[typing.Tuple[str, int]] = None, - ) -> None: - super().__init__() - self._settings = settings - self._topic = topic - self._graph_address = address - - def run(self) -> None: - comps = {"SHMEM": ShMemCircBuff(self._settings)} - conns = ((self._topic, comps["SHMEM"].INPUT_SIGNAL),) - ez.run(components=comps, connections=conns, graph_address=self._graph_address) - - -class EZProcManager: - """ - Manages the subprocess that runs an ezmsg pipeline comprising a single ShMemCircBuff unit connected to a pipeline. - The unit must be parameterized with the correct shared memory name. - We do not actually interact with the shared memory in this class. See .mirror.EzmsgShmMirror. - """ - - def __init__(self, graph_ip: str, graph_port: int, buf_dur: float = BUF_DUR) -> None: - self._graph_addr: typing.Tuple[str, int] = (graph_ip, graph_port) - self._buf_dur = buf_dur - self._proc = None - self._node_path: typing.Optional[str] = None - self._remote_conn, self._conn = multiprocessing.Pipe() - - @property - def node_path(self) -> str: - return self._node_path - - @property - def conn(self) -> typing.Optional[multiprocessing.connection.Connection]: - return self._conn - - def reset(self, node_path: typing.Optional[str]) -> None: - self._cleanup_subprocess() - self._node_path = node_path - self._init_subprocess() - - def cleanup(self): - self._cleanup_subprocess() - - def _cleanup_subprocess(self) -> None: - if self._proc is not None: - self._conn.send("quit") - # Close process - self._proc.join() - self._proc = None - - # TODO: Somehow closing the proc doesn't always clear the VISBUFF connections. - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete( - ez.graphserver.GraphService(address=self._graph_addr).disconnect( - self._node_path, "VISBUFF/INPUT_SIGNAL" - ) - ) - - def _init_subprocess(self, axis: str = "time"): - unit_settings = ShMemCircBuffSettings( - shmem_name="buff_" + self._node_path, - buf_dur=self._buf_dur, - conn=self._remote_conn, - axis=axis, - ) - self._proc = EzMonitorProcess(unit_settings, self._node_path, address=self._graph_addr) - self._proc.start() - - # if self._rend_conn.poll(): msg = self._rend_conn.recv() diff --git a/src/ezmsg/tools/sigmon/cli.py b/src/ezmsg/tools/sigmon/cli.py new file mode 100644 index 0000000..67aa456 --- /dev/null +++ b/src/ezmsg/tools/sigmon/cli.py @@ -0,0 +1,250 @@ +"""Sigmon — real-time ezmsg graph inspector using Qt + phosphor.""" + +import logging +import sys + +import numpy as np +import typer +from ezmsg.qt import EzDynamicSubscriber, EzGuiBridge +from phosphor import ( + ScatterConfig, + ScatterWidget, + SpectrumConfig, + SpectrumWidget, + SweepConfig, + SweepWidget, +) +from PySide6.QtCore import Qt +from PySide6.QtGui import QKeySequence, QShortcut +from PySide6.QtWidgets import QApplication, QMainWindow, QSplitter, QWidget + +from ezmsg.tools.sigmon.dag_widget import DAGWidget + +logger = logging.getLogger(__name__) + +GRAPH_IP = "127.0.0.1" +GRAPH_PORT = 25978 + + +def _extract_channel_meta(msg) -> tuple[list[str] | None, np.ndarray | None]: + """Extract channel labels and 2D positions from AxisArray channel metadata. + + Returns ``(labels, positions)`` where *positions* is ``(n_channels, 2)`` + float32 or ``None`` if no location fields are present. + """ + if "ch" not in msg.dims: + return None, None + + ch_axis = msg.get_axis("ch") + ch_data = getattr(ch_axis, "data", None) + if ch_data is None or ch_data.dtype.names is None: + return None, None + + # Labels + labels = None + if "label" in ch_data.dtype.names: + labels = [str(v) for v in ch_data["label"]] + + # Positions — need at least x and y + positions = None + if "x" in ch_data.dtype.names and "y" in ch_data.dtype.names: + x = ch_data["x"].astype(np.float32) + y = ch_data["y"].astype(np.float32) + if np.any(x != 0) or np.any(y != 0): + positions = np.column_stack([x, y]) + + return labels, positions + + +class SigmonWindow(QMainWindow): + def __init__( + self, + graph_address: tuple[str, int], + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self.setWindowTitle("ezmsg Signal Monitor") + self._graph_address = graph_address + + # Dynamic subscriber — switches topics when the user clicks a graph node. + self._data_sub = EzDynamicSubscriber(parent=self) + self._data_sub.connect(self._on_data) + + # Layout: splitter with DAG on left, plot on right. + self._splitter = QSplitter(Qt.Orientation.Horizontal) + self.setCentralWidget(self._splitter) + + self._dag_widget = DAGWidget(graph_address) + self._dag_widget.node_selected.connect(self._on_node_selected) + self._splitter.addWidget(self._dag_widget) + + self._plot_widget: QWidget = QWidget() # placeholder + self._splitter.addWidget(self._plot_widget) + self._splitter.setStretchFactor(0, 1) + self._splitter.setStretchFactor(1, 3) + + self._first_message = True + + # Channel metadata cached from the first message of each topic. + self._channel_labels: list[str] | None = None + self._channel_positions: np.ndarray | None = None + # Cached parameters for rebuilding the primary (sweep/spectrum) widget. + self._primary_config: SweepConfig | SpectrumConfig | None = None + self._showing_scatter = False + + # Hotkey: "M" toggles scatter/map view (only effective when positions exist). + shortcut = QShortcut(QKeySequence("M"), self) + shortcut.activated.connect(self._toggle_scatter) + + def _on_node_selected(self, topic: str) -> None: + self._data_sub.subscribe(topic) + self._first_message = True + self._channel_labels = None + self._channel_positions = None + self._primary_config = None + self._showing_scatter = False + + def _on_data(self, msg) -> None: + """Handle a message delivered by the dynamic subscriber.""" + if self._first_message: + self._channel_labels, self._channel_positions = _extract_channel_meta(msg) + self._create_plot_widget(msg) + self._first_message = False + + self._push_message(msg) + + def _create_plot_widget(self, msg) -> None: + """Detect data type from AxisArray dims and create the appropriate widget.""" + labels = self._channel_labels + + if "time" in msg.dims: + time_axis = msg.get_axis("time") + srate = 1.0 / time_axis.gain + time_idx = msg.get_axis_idx("time") + n_samples = msg.shape[time_idx] + n_channels = msg.data.size // n_samples + + config = SweepConfig( + n_channels=n_channels, + srate=srate, + channel_labels=labels, + ) + widget = SweepWidget(config) + + elif "freq" in msg.dims: + freq_axis = msg.get_axis("freq") + freq_idx = msg.get_axis_idx("freq") + n_bins = msg.shape[freq_idx] + srate = 2.0 * freq_axis.gain * n_bins + n_channels = msg.data.size // n_bins + + config = SpectrumConfig( + n_channels=n_channels, + srate=srate, + n_bins=n_bins, + channel_labels=labels, + ) + widget = SpectrumWidget(config) + + else: + logger.warning("Unknown AxisArray dims: %s — defaulting to sweep", msg.dims) + n_samples = msg.shape[0] + n_channels = msg.data.size // n_samples if n_samples > 0 else 1 + config = SweepConfig( + n_channels=n_channels, + srate=1000.0, + channel_labels=labels, + ) + widget = SweepWidget(config) + + self._primary_config = config + self._showing_scatter = False + self._replace_plot_widget(widget) + + def _toggle_scatter(self) -> None: + """Toggle between primary (sweep/spectrum) view and scatter/map view.""" + if self._channel_positions is None: + return # no locations — nothing to toggle + + if self._showing_scatter: + # Switch back to primary widget. + if isinstance(self._primary_config, SweepConfig): + widget = SweepWidget(self._primary_config) + elif isinstance(self._primary_config, SpectrumConfig): + widget = SpectrumWidget(self._primary_config) + else: + return + self._showing_scatter = False + else: + config = ScatterConfig( + positions=self._channel_positions, + channel_labels=self._channel_labels, + ) + widget = ScatterWidget(config) + self._showing_scatter = True + + self._replace_plot_widget(widget) + + def _replace_plot_widget(self, widget: QWidget) -> None: + """Swap the right pane of the splitter.""" + sizes = self._splitter.sizes() + old = self._splitter.widget(1) + if old is not None: + old.setParent(None) + old.deleteLater() + self._splitter.insertWidget(1, widget) + self._splitter.setStretchFactor(1, 3) + self._splitter.setSizes(sizes) + self._plot_widget = widget + + def _push_message(self, msg) -> None: + """Extract 2D data from AxisArray and push to the plot widget.""" + widget = self._plot_widget + + if isinstance(widget, SweepWidget): + time_idx = msg.get_axis_idx("time") if "time" in msg.dims else 0 + n_samples = msg.shape[time_idx] + n_channels = msg.data.size // n_samples if n_samples > 0 else 1 + data_2d = np.moveaxis(msg.data, time_idx, 0).reshape(n_samples, n_channels) + widget.push_data(data_2d.astype(np.float32)) + + elif isinstance(widget, SpectrumWidget): + freq_idx = msg.get_axis_idx("freq") if "freq" in msg.dims else 0 + n_bins = msg.shape[freq_idx] + n_channels = msg.data.size // n_bins if n_bins > 0 else 1 + data_2d = np.moveaxis(msg.data, freq_idx, 0).reshape(n_bins, n_channels) + widget.push_data(data_2d.astype(np.float32)) + + elif isinstance(widget, ScatterWidget): + # Scatter expects (n_channels,) or (n_samples, n_channels). + if len(msg.shape) > 1: + targ_idx = 0 + if "time" in msg.dims or "freq" in msg.dims: + targ_idx = msg.get_axis_idx("time") if "time" in msg.dims else msg.get_axis_idx("freq") + n_items = msg.shape[targ_idx] + n_channels = msg.data.size // n_items if n_items > 0 else 1 + data_2d = np.moveaxis(msg.data, targ_idx, 0).reshape(n_items, n_channels) + else: + data_2d = msg.data.reshape(1, msg.data.size) + widget.push_data(data_2d.astype(np.float32)) + + +def _run( + graph_addr: str = ":".join((GRAPH_IP, str(GRAPH_PORT))), +) -> None: + graph_ip, graph_port_str = graph_addr.split(":") + graph_address = (graph_ip, int(graph_port_str)) + + app = QApplication.instance() or QApplication(sys.argv) + window = SigmonWindow(graph_address) + window.showMaximized() + with EzGuiBridge(app, graph_address=graph_address): + app.exec() + + +def main() -> None: + typer.run(_run) + + +if __name__ == "__main__": + main() diff --git a/src/ezmsg/tools/sigmon/dag_widget.py b/src/ezmsg/tools/sigmon/dag_widget.py new file mode 100644 index 0000000..5c9a741 --- /dev/null +++ b/src/ezmsg/tools/sigmon/dag_widget.py @@ -0,0 +1,160 @@ +"""Qt widget for interactive DAG visualization.""" + +import logging +import tempfile +import xml.etree.ElementTree as ET +from pathlib import Path + +import pandas as pd +from PySide6.QtCore import QEvent, QObject, QSize, Signal +from PySide6.QtGui import QMouseEvent, QPixmap, QResizeEvent +from PySide6.QtWidgets import ( + QApplication, + QLabel, + QPushButton, + QScrollArea, + QVBoxLayout, + QWidget, +) + +from ..dag import get_graph, pgv2pd + +logger = logging.getLogger(__name__) + + +class DAGWidget(QWidget): + """Interactive DAG graph view with click-to-select nodes.""" + + node_selected = Signal(str) + + def __init__( + self, + graph_address: tuple[str, int], + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._graph_address = graph_address + self._node_df = pd.DataFrame(columns=["name", "x", "y", "upstream"]) + self._full_pixmap: QPixmap | None = None + self._display_scale: float = 1.0 + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + self._refresh_btn = QPushButton("Refresh") + self._refresh_btn.clicked.connect(self._refresh_graph) + layout.addWidget(self._refresh_btn) + + self._scroll_area = QScrollArea() + self._scroll_area.setWidgetResizable(False) + layout.addWidget(self._scroll_area) + + self._label = QLabel() + self._label.setScaledContents(False) + self._scroll_area.setWidget(self._label) + self._scroll_area.viewport().installEventFilter(self) + + self._refresh_graph() + + def _refresh_graph(self) -> None: + G = get_graph(self._graph_address) + + if len(G.nodes()) == 0: + self._node_df = pd.DataFrame(columns=["name", "x", "y", "upstream"]) + self._full_pixmap = None + self._label.setText("No graph connected") + return + + G.layout(prog="dot") + + # Render SVG for coordinate extraction. + svg_path = Path(tempfile.gettempdir()) / "ezmsg-graphviz.svg" + G.draw(svg_path, format="svg:cairo") + + # Extract node positions from the layout. + self._node_df = pgv2pd(G) + + # Parse SVG viewBox for native graphviz coordinate dimensions — + # avoids Qt's DPI-dependent SVG rasterization (wrong on Retina). + tree = ET.parse(svg_path) + root = tree.getroot() + viewbox = root.get("viewBox") + if viewbox: + parts = viewbox.split() + svg_width = float(parts[2]) + svg_height = float(parts[3]) + else: + # Fallback: parse width/height attributes (e.g. "400pt"). + svg_width = float("".join(c for c in root.get("width", "1") if c.isdigit() or c == ".")) + svg_height = float("".join(c for c in root.get("height", "1") if c.isdigit() or c == ".")) + + # Render PNG at a resolution that fills the screen width, so the + # image stays sharp even when the panel is expanded to full size. + screen = QApplication.primaryScreen() + if screen is not None and svg_width > 0: + target_px = screen.size().width() * screen.devicePixelRatio() + dpi = max(96, int(target_px * 72 / svg_width)) + else: + dpi = 96 + G.graph_attr["dpi"] = str(dpi) + + img_path = Path(tempfile.gettempdir()) / "ezmsg-graphviz.png" + G.draw(img_path) + self._full_pixmap = QPixmap(str(img_path)) + + # Compute coordinate scale from SVG (points) to PNG (pixels). + x_scale = self._full_pixmap.width() / svg_width if svg_width else 1.0 + y_scale = self._full_pixmap.height() / svg_height if svg_height else 1.0 + + self._node_df["x"] *= x_scale + self._node_df["y"] *= y_scale + # Invert Y so origin is top-left (matching PNG pixel coords). + self._node_df["y"] = self._full_pixmap.height() - self._node_df["y"] + + self._scale_and_display() + + def _scale_and_display(self) -> None: + if self._full_pixmap is None: + return + available_width = self._scroll_area.viewport().width() + if available_width <= 0: + available_width = self.width() + scaled = self._full_pixmap.scaledToWidth(max(available_width, 1)) + self._display_scale = scaled.width() / self._full_pixmap.width() if self._full_pixmap.width() else 1.0 + self._label.setPixmap(scaled) + self._label.adjustSize() + + def resizeEvent(self, event: QResizeEvent) -> None: + super().resizeEvent(event) + self._scale_and_display() + + def eventFilter(self, obj: QObject, event: QEvent) -> bool: + if obj is self._scroll_area.viewport() and event.type() == QEvent.Type.MouseButtonPress: + self._handle_click(event) + return True + return super().eventFilter(obj, event) + + def _handle_click(self, event: QMouseEvent) -> None: + if self._full_pixmap is None or len(self._node_df) == 0: + return + + # Viewport click + scroll offset = position in scaled image. + vx = event.position().x() + vy = event.position().y() + sx = self._scroll_area.horizontalScrollBar().value() + sy = self._scroll_area.verticalScrollBar().value() + img_x = vx + sx + img_y = vy + sy + + # Convert to full-pixmap coords. + px_x = img_x / self._display_scale if self._display_scale > 0 else img_x + px_y = img_y / self._display_scale if self._display_scale > 0 else img_y + + # Find nearest node via Euclidean distance. + dist_sq = (self._node_df["x"] - px_x) ** 2 + (self._node_df["y"] - px_y) ** 2 + min_idx = dist_sq.argmin() + topic = str(self._node_df.iloc[min_idx]["upstream"]) + self.node_selected.emit(topic) + + def sizeHint(self) -> QSize: + return QSize(300, 600) diff --git a/src/ezmsg/tools/sigmon/main.py b/src/ezmsg/tools/sigmon/main.py deleted file mode 100644 index 4113038..0000000 --- a/src/ezmsg/tools/sigmon/main.py +++ /dev/null @@ -1,95 +0,0 @@ -import pygame -import pygame.locals -import typer - -from ezmsg.tools.proc import EZProcManager -from ezmsg.tools.shmem.shmem_mirror import EZShmMirror -from ezmsg.tools.sigmon.ui.dag import VisDAG -from ezmsg.tools.sigmon.ui.timeseries import Sweep - -GRAPH_IP = "127.0.0.1" -GRAPH_PORT = 25978 -PLOT_DUR = 2.0 - - -def main( - graph_addr: str = ":".join((GRAPH_IP, str(GRAPH_PORT))), -): - pygame.init() - - # Screen - screen = pygame.display.set_mode((0, 0), pygame.FULLSCREEN) - screen_width, screen_height = screen.get_size() - screen = pygame.display.set_mode((screen_width, screen_height), pygame.locals.RESIZABLE) - screen.fill((0, 0, 0)) # Fill the screen with black - - # Interactive ezmsg graph. Its purpose is to show the graph (w/ scrolling) - # and get the name of the node that was clicked on and that we want to visualize. - graph_ip, graph_port = graph_addr.split(":") - graph_port = int(graph_port) - dag = VisDAG(screen_height=screen_height, graph_ip=graph_ip, graph_port=graph_port) - - # ezmsg process manager -- the process runs a mini ezmsg pipeline - # that attaches a single node to an existing pipeline. We don't - # know the attachment point yet, so we do not start the pipeline. - ez_proc_man = EZProcManager( - graph_ip=graph_ip, - graph_port=graph_port, - buf_dur=PLOT_DUR, - ) - - # We need an in-process mirror to the out-of-process ShMemCircBuff - # in `ez_proc_man`. It initializes in a waiting state because the - # remote unit does not exist until EZProcManager starts up. - mirror = EZShmMirror() - - # Data Plotter. Puts a surface on the screen, plots 2D lines - # with some basic auto-scaling. ezmsg-graphviz renderers are - # highly customized to use the mirror object as it uses - # the mirror's shmem buffer as its own rendering buffer. - sweep = Sweep( - mirror, - (screen_width - dag.size[0], screen_height), - tl_offset=(dag.size[0], 0), - dur=PLOT_DUR, - ) - - running = True - while running: - new_node_path = None - for event in pygame.event.get(): - if event.type == pygame.QUIT: - running = False - break - elif event.type == pygame.KEYDOWN: - # Keyboard presses - if event.key == pygame.K_ESCAPE: - # Close the application when Esc key is pressed - running = False - break - new_node_path = dag.handle_event(event) - _ = sweep.handle_event(event) # Currently does nothing - - if new_node_path is not None and new_node_path != ez_proc_man.node_path: - # Clicked on a new node to monitor - ez_proc_man.reset(new_node_path) # Close subprocess and start a new one. - sweep.reset(new_node_path) - - # Remaining initialization must wait until subprocess has seen data. - - # Refresh / scroll dag image if required - rects = dag.update(screen) - - # Update the sweep plot (internally it uses shmem) - rects += sweep.update(screen) - - pygame.display.update(rects) - - sweep.reset(None) - ez_proc_man.cleanup() - - pygame.quit() - - -if __name__ == "__main__": - typer.run(main) diff --git a/src/ezmsg/tools/sigmon/ui/__init__.py b/src/ezmsg/tools/sigmon/ui/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/ezmsg/tools/sigmon/ui/base.py b/src/ezmsg/tools/sigmon/ui/base.py deleted file mode 100644 index 891e9ca..0000000 --- a/src/ezmsg/tools/sigmon/ui/base.py +++ /dev/null @@ -1,96 +0,0 @@ -import typing - -import pygame - -from ...shmem.shmem_mirror import EZShmMirror - -PLOT_BG_COLOR = (255, 255, 255) -PLOT_FONT_COLOR = (0, 0, 0) -PLOT_DUR = 2.0 - - -class BaseRenderer(pygame.Surface): - """ - This is an abstract class representing a pygame.Surface that also manages - a subprocess running ezmsg as well as shared memory to communicate with that - subprocess. - """ - - def __init__( - self, - mirror: EZShmMirror, - *args, - tl_offset: typing.Tuple[int, int] = (0, 0), - **kwargs, - ): - super().__init__(*args, **kwargs) - self._mirror = mirror - self._tl_offset: typing.Tuple[int, int] = tl_offset - self._plot_rect = self.get_rect(topleft=self._tl_offset) - self._node_path: typing.Optional[str] = None - self._font = pygame.font.Font(None, 36) # Default font and size 36 - self._refresh_text = True - self._plot_needs_reset = True - self.fill(PLOT_BG_COLOR) - - def handle_event(self, event: pygame.event.Event): - if event.type in [pygame.MOUSEWHEEL, pygame.MOUSEBUTTONDOWN]: - pass - # TODO: Check if mouse_pos is over self - # mouse_pos = pygame.mouse.get_pos() - # TODO: Respond to mouse. - - def _reset_plot(self): - raise NotImplementedError - - def reset(self, node_path: typing.Optional[str]) -> None: - self._mirror.disconnect() - self.fill(PLOT_BG_COLOR) - if node_path is not None and node_path != self._node_path: - self._node_path = node_path - self._refresh_text = True - self._plot_needs_reset = True - # This is all we can do until the metadata becomes available. - - def _print_node_path(self, surface: pygame.Surface) -> pygame.Rect: - # TEMP: Render the node_path - meta = self._mirror.meta - if meta is not None: - self._mirror.connect("buff_" + self._node_path) - - import numpy as np - - buf_shape = meta.shape[: meta.ndim] - buf_dtype = np.dtype(meta.dtype).name - src_str = f"{self._node_path} {buf_shape}, {buf_dtype}" - else: - src_str = self._node_path - text_surface = self._font.render( - src_str, - True, - PLOT_FONT_COLOR, - ) - text_rect = text_surface.get_rect(midtop=self._plot_rect.midtop) - # Draw a background rectangle for the text - pygame.draw.rect(surface, (200, 200, 200), self._plot_rect) - # Draw the actual text - surface.blit(text_surface, text_rect) - pygame.display.update(text_rect) - return text_rect - - def update(self, surface: pygame.Surface) -> typing.List[pygame.Rect]: - rects = [] - - if not self._mirror.connected and self._node_path is not None: - self._mirror.connect("buff_" + self._node_path) - - if self._mirror.connected and self._plot_needs_reset: - self._reset_plot() - self._plot_needs_reset = False - self._refresh_text = True - - if self._refresh_text: - rects.append(self._print_node_path(surface)) - self._refresh_text = False - - return rects diff --git a/src/ezmsg/tools/sigmon/ui/dag.py b/src/ezmsg/tools/sigmon/ui/dag.py deleted file mode 100644 index 213db93..0000000 --- a/src/ezmsg/tools/sigmon/ui/dag.py +++ /dev/null @@ -1,101 +0,0 @@ -import sys -import tempfile -import typing -from pathlib import Path - -import pygame -import pygame.event - -from ...dag import get_graph, pgv2pd - -SCROLL_STEP = 50 - - -class VisDAG: - def __init__( - self, - tl_offset: typing.Tuple[int, int] = (0, 0), - screen_height: int = 1440, - graph_ip: str = "127.0.0.1", - graph_port: int = 25978, - ): - self._screen_height = screen_height - G = get_graph((graph_ip, graph_port)) - G.layout(prog="dot") - # Create SVG to get the correct coordinates - svg_path = Path(tempfile.gettempdir()) / "ezmsg-graphviz.svg" - G.draw(svg_path, format="svg:cairo") - # Get the graph details as dataframe - self._node_df = pgv2pd(G) - # Unfortunately, pygame cannot render svg very well, so we render as png for display - img_path = Path(tempfile.gettempdir()) / "ezmsg-graphviz.png" - G.draw(img_path) - self._image = pygame.image.load(img_path) - self._image_rect = self._image.get_rect(topleft=tl_offset) - self._min_y = screen_height - self._image_rect.height - - if sys.platform == "win32": - # On Windows, it looks like we need to scale the svg coordinates by the window dims. - x_scale = self._image_rect.width / (self._node_df["x"].max() + self._node_df["x"].min()) - y_scale = self._image_rect.height / (self._node_df["y"].max() + self._node_df["y"].min()) - else: - # Scale the coordinates in the dataframe by png size / svg size - _svg = pygame.image.load(svg_path) - x_scale = self._image_rect.width / _svg.get_rect().width - y_scale = self._image_rect.height / _svg.get_rect().height - - self._node_df["y"] *= y_scale - self._node_df["x"] *= x_scale - # Invert the y coordinates of the image so origin is top-left, like in pygame - self._node_df["y"] = self._image_rect.height - self._node_df["y"] - - self._image_y = 0 # Initial offset of the image - self._b_update = True - - @property - def size(self) -> typing.Tuple[int, int]: - return self._image_rect.size - - def handle_event(self, event: pygame.event.Event) -> typing.Optional[str]: - clicked_node_path = None - if event.type in [pygame.MOUSEWHEEL, pygame.MOUSEBUTTONDOWN]: - mouse_pos = pygame.mouse.get_pos() - if self._image_rect.left <= mouse_pos[0] <= self._image_rect.right: - if event.type == pygame.MOUSEWHEEL: - # The image of the dag is scrolled. `_image_y` is the offset for the top of the image. - # We scroll down (shift image up) by making the top of the image more negative. - if event.y > 0: - # scroll graph up - self._image_y = min(0, self._image_y + SCROLL_STEP) - elif event.y < 0: - # scroll graph down - self._image_y = max(self._min_y, self._image_y - SCROLL_STEP) - self._b_update = True - - elif event.type == pygame.MOUSEBUTTONDOWN: - # Mouse events - if event.button == 1: - # Clicked on the screen over the DAG. - # Calculate the position of the click from screen coordinates to DAG coordinates. - # (On a Mac at least) - # The mouse coordinates are top-left is origin, right is positive x, down is positive y. - # The dag _image_rect is left: 0, right: width, top: 0, bottom: height. - # We must add -1 * _image_y to compensate for the pixels of the image shifted up off the screen. - graph_pos = ( - mouse_pos[0] - self._image_rect.left, - mouse_pos[1] - self._image_rect.top - self._image_y, - ) - min_row = ( - (self._node_df.x - graph_pos[0]) ** 2 + (self._node_df.y - graph_pos[1]) ** 2 - ).argmin() - clicked_node_path = f"{self._node_df.iloc[min_row]['upstream']}" - return clicked_node_path - - def update(self, surface: pygame.Surface) -> typing.List[pygame.Rect]: - res = [] - if self._b_update: - surface.blit(self._image, (0, self._image_y)) - pygame.display.update(self._image_rect) - res.append(self._image_rect) - self._b_update = False - return res diff --git a/src/ezmsg/tools/sigmon/ui/timeseries.py b/src/ezmsg/tools/sigmon/ui/timeseries.py deleted file mode 100644 index 4ee9817..0000000 --- a/src/ezmsg/tools/sigmon/ui/timeseries.py +++ /dev/null @@ -1,263 +0,0 @@ -import typing - -import numpy as np -import numpy.typing as npt -import pygame - -from .base import PLOT_DUR, BaseRenderer - -PLOT_BG_COLOR = (255, 255, 255) -PLOT_LINE_COLOR = (0, 0, 0) -INIT_Y_RANGE = 1e4 # Raw units per channel - - -def running_stats( - fs: float, - time_constant: float = PLOT_DUR, -) -> typing.Generator[typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray, None]: - arr_in = np.array([]) - tuple_out = (np.array([]), np.array([])) - means = vars_means = vars_sq_means = None - alpha = 1 - np.exp(-1 / (fs * time_constant)) - - def _ew_update(arr, prev, _alpha): - if np.all(prev == 0): - return arr - # return _alpha * arr + (1 - _alpha) * prev - # Micro-optimization: sub, mult, add (below) is faster than sub, mult, mult, add (above) - return prev + _alpha * (arr - prev) - - while True: - arr_in = yield tuple_out - - if means is None: - vars_sq_means = np.zeros_like(arr_in[0], dtype=float) - vars_means = np.zeros_like(arr_in[0], dtype=float) - means = np.zeros_like(arr_in[0], dtype=float) - - for sample in arr_in: - # Update step - vars_means = _ew_update(sample, vars_means, alpha) - vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha) - means = _ew_update(sample, means, alpha) - tuple_out = means, np.sqrt(vars_sq_means - vars_means**2) - - -class Sweep(BaseRenderer): - def __init__( - self, - *args, - yrange: float = INIT_Y_RANGE, - autoscale: bool = True, - dur: float = PLOT_DUR, - **kwargs, - ): - super().__init__(*args, **kwargs) - self._y_range = yrange - self._autoscale = autoscale - self._dur = dur - self._xvec = np.array([]) # Vector of indices - self._plot_x_idx = 0 # index into xvec where the next plot starts. - self._read_index = 0 # Index into shmem buffer - self._stats_gen: typing.Optional[typing.Generator] = None - self._last_y_vec: typing.Optional[npt.NDArray] = None - self._x2px: float = 1.0 - - def _reset_plot(self): - # Reset plot parameters - meta = self._mirror.meta - plot_samples = int(self._dur * meta.srate) - self._xvec = np.arange(plot_samples) - self._x2px = self._plot_rect.width / plot_samples - self._stats_gen = running_stats(meta.srate, time_constant=self._dur) - self._stats_gen.send(None) # Prime the generator - self._plot_x_idx = 0 - self._read_index = 0 - self._last_y_vec = None - # Blank the surface - self.fill(PLOT_BG_COLOR) - pygame.display.update(self._plot_rect) - if meta.ndim > 2: - # Monkey-patch udpate func to do nothing - print("timeseries does not support > 2 dimensions") - - def update_with_copy(self, surface: pygame.Surface) -> typing.List[pygame.Rect]: - rects = super().update(surface) - data = self._mirror.auto_view(n=None) - if data is not None: - if self._autoscale: - # Check if the scale has changed. - means, stds = self._stats_gen.send(data) - new_y_range = 3 * np.mean(stds) - b_reset_scale = new_y_range < 0.8 * self._y_range or new_y_range > 1.2 * self._y_range - if b_reset_scale: - self._y_range = new_y_range - # TODO: We should also redraw the entire plot at the new scale. - # However, we do not have a copy of all visible data. - - n_chs = data.shape[1] - yoffsets = (np.arange(n_chs) + 0.5) * self._y_range - y_span = (n_chs + 1) * self._y_range - y2px = self._plot_rect.height / y_span - - # Establish the minimum rectangle for the update - n_samps = data.shape[0] - dat_offset = 0 - while n_samps > 0: - x0 = self._plot_x_idx - b_prepend = x0 != 0 and self._last_y_vec is not None - if b_prepend: - xvec = self._xvec[x0 - 1 : x0 + n_samps] - if dat_offset == 0: - _data = np.concatenate([self._last_y_vec, data[: xvec.shape[0] - 1]], axis=0) - else: - _data = data[dat_offset - 1 : dat_offset + xvec.shape[0] - 1] - else: - xvec = self._xvec[x0 : x0 + n_samps] - _data = data[dat_offset : dat_offset + xvec.shape[0]] - - # Identify the rectangle that we will be plotting over. - _rect_x = ( - int(xvec[0] * self._x2px), - int(np.ceil(xvec[-1] * self._x2px)), - ) - update_rect = pygame.Rect( - (_rect_x[0], 0), - (_rect_x[1] - _rect_x[0] + 5, self._plot_rect.height), - ) - - # Blank the rectangle with bgcolor - pygame.draw.rect(self, PLOT_BG_COLOR, update_rect) - - # Plot the lines - if _data.shape[0] > 1: - for ch_ix, ch_offset in enumerate(yoffsets): - plot_dat = _data[:, ch_ix] + ch_offset - try: - xy = np.column_stack((xvec * self._x2px, plot_dat * y2px)) - except ValueError: - print("DEBUG") - pygame.draw.lines(self, PLOT_LINE_COLOR, 0, xy) - - # Blit the surface - _rect = surface.blit( - self, - ( - self._tl_offset[0] + update_rect.x, - self._tl_offset[1], - ), - update_rect, - ) - rects.append(_rect) - - n_new = (xvec.shape[0] - 1) if b_prepend else xvec.shape[0] - self._plot_x_idx += n_new - self._plot_x_idx %= self._xvec.shape[0] - n_samps -= n_new - dat_offset += n_new - self._last_y_vec = _data[-1:].copy() - - # Draw cursor - curs_x = int(((self._plot_x_idx + 1) % self._xvec.shape[0]) * self._x2px) - curs_rect = pygame.draw.line( - self, - PLOT_LINE_COLOR, - (curs_x, 0), - (curs_x, self._plot_rect.height), - ) - _rect = surface.blit( - self, - ( - self._tl_offset[0] + curs_rect.x, - self._tl_offset[1], - ), - curs_rect, - ) - rects.append(_rect) - - return rects - - def update(self, surface: pygame.Surface) -> typing.List[pygame.Rect]: - rects = super().update(surface) - - res, b_overflow = self._mirror.auto_view() - if res.size == 0: - return rects - - if self._plot_needs_reset: - return rects - - meta = self._mirror.meta - if meta.ndim > 2: - return rects - n_samples = res.shape[0] - - t_slice = np.s_[max(0, self._read_index - 1) : self._read_index + n_samples] - if self._autoscale: - means, stds = self._stats_gen.send(self._mirror.buffer[t_slice]) - new_y_range = max(3 * np.mean(stds), 1e-12) - b_reset_scale = new_y_range < 0.8 * self._y_range or new_y_range > 1.2 * self._y_range - if b_reset_scale: - self._y_range = new_y_range - t_slice = np.s_[:] - - n_chs = res.shape[1] - yoffsets = (np.arange(n_chs) + 0.5) * self._y_range - y_span = (n_chs + 1) * self._y_range - y2px = self._plot_rect.height / y_span - - _x = self._xvec[t_slice] - _rect_x = (int(_x[0] * self._x2px), int(np.ceil(_x[-1] * self._x2px))) - update_rect = pygame.Rect( - (_rect_x[0], 0), - (_rect_x[1] - _rect_x[0] + 5, self._plot_rect.height), - ) - # Blank the rectangle with bgcolor - pygame.draw.rect(self, PLOT_BG_COLOR, update_rect) - - # Plot the lines - for ch_ix, ch_offset in enumerate(yoffsets): - plot_dat = self._mirror.buffer[t_slice, ch_ix] + ch_offset - try: - xy = np.column_stack((_x * self._x2px, plot_dat * y2px)) - except ValueError: - print(_x.shape, plot_dat.shape) - raise - pygame.draw.lines(self, PLOT_LINE_COLOR, 0, xy) - - self._read_index = (self._read_index + n_samples) % self._xvec.shape[0] - - # Draw cursor - curs_x = int(((self._read_index + 1) % self._xvec.shape[0]) * self._x2px) - pygame.draw.line( - self, - PLOT_LINE_COLOR, - (curs_x, 0), - (curs_x, self._plot_rect.height), - ) - - # Update - _rect = surface.blit( - self, - ( - self._tl_offset[0] + update_rect.x, - self._tl_offset[1], - ), - update_rect, - ) - rects.append(_rect) - return rects - - def handle_event(self, event: pygame.event.Event): - if event.type in [pygame.KEYDOWN]: - if event.key == pygame.K_a: - # Toggle autoscale with 'a' key - self._autoscale = not self._autoscale - elif not self._autoscale: - # When autoscale is disabled, allow manual y-range adjustment - if event.key == pygame.K_MINUS: - # Zoom-Out: Increase y-range by 20% with '-' key - self._y_range *= 1.2 - elif event.key == pygame.K_EQUALS: - # Zoom-in: Decrease y-range by 20% with '=' key - self._y_range *= 0.8 diff --git a/tests/test_shmem_mirror.py b/tests/test_shmem_mirror.py index 98b70de..30305ff 100644 --- a/tests/test_shmem_mirror.py +++ b/tests/test_shmem_mirror.py @@ -9,7 +9,7 @@ import ezmsg.core as ez import numpy as np import pytest -from ezmsg.sigproc.synth import Clock, Oscillator +from ezmsg.simbiophys.eeg import EEGSynth from ezmsg.util.messagecodec import message_log from ezmsg.util.messagelogger import MessageLogger from ezmsg.util.messages.axisarray import AxisArray @@ -85,15 +85,13 @@ def app(file_path) -> None: n_messages = int(TOTAL_DURATION * chunk_rate) comps = { - "CLOCK": Clock(dispatch_rate=chunk_rate), - "SYNTH": Oscillator(n_time=chunk_size, fs=SR, n_ch=CHANNEL_COUNT, dispatch_rate="ext_clock"), + "SYNTH": EEGSynth(fs=SR, n_time=chunk_size, n_ch=CHANNEL_COUNT), "CRAZY": CrazyUnit(change_after=n_messages // 2, change_type=change_type), "SINK": ShMemCircBuff(SHMEM_NAME, 2.0, conn=None, axis="time"), "LOGGER": MessageLogger(output=file_path), "TERM": TerminateOnTotal(total=n_messages), } conns = ( - (comps["CLOCK"].OUTPUT_SIGNAL, comps["SYNTH"].INPUT_SIGNAL), (comps["SYNTH"].OUTPUT_SIGNAL, comps["CRAZY"].INPUT_SIGNAL), (comps["CRAZY"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL), (comps["CRAZY"].OUTPUT_SIGNAL, comps["LOGGER"].INPUT_MESSAGE), diff --git a/tests/test_shmem_sink.py b/tests/test_shmem_sink.py index f14de46..1b5192b 100644 --- a/tests/test_shmem_sink.py +++ b/tests/test_shmem_sink.py @@ -6,7 +6,7 @@ import ezmsg.core as ez import numpy as np import pytest -from ezmsg.sigproc.synth import Clock, Oscillator +from ezmsg.simbiophys.eeg import EEGSynth from ezmsg.util.messagecodec import message_log from ezmsg.util.messagelogger import MessageLogger from ezmsg.util.messages.axisarray import AxisArray @@ -82,15 +82,13 @@ def test_shmem_change(change_type: str): file_path.unlink(missing_ok=True) comps = { - "CLOCK": Clock(dispatch_rate=100.0), - "SYNTH": Oscillator(n_time=10, fs=1000, n_ch=n_ch, dispatch_rate="ext_clock"), + "SYNTH": EEGSynth(fs=1000, n_time=10, n_ch=n_ch), "CRAZY": CrazyUnit(change_after=n_messages // 2, change_type=change_type), "SINK": ShMemCircBuff(SHMEM_NAME, 2.0, conn=None, axis="time"), "LOGGER": MessageLogger(output=file_path), "TERM": TerminateOnTotal(total=n_messages), } conns = ( - (comps["CLOCK"].OUTPUT_SIGNAL, comps["SYNTH"].INPUT_SIGNAL), (comps["SYNTH"].OUTPUT_SIGNAL, comps["CRAZY"].INPUT_SIGNAL), (comps["CRAZY"].OUTPUT_SIGNAL, comps["LOGGER"].INPUT_MESSAGE), (comps["LOGGER"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE),