-
Notifications
You must be signed in to change notification settings - Fork 54
Add pre-test hook support for MegatronBridge workload #957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -73,13 +73,76 @@ def gen_exec_command(self) -> str: | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| launcher_py = (mbridge_repo_path / "scripts" / "performance" / "setup_experiment.py").absolute() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| parts = self._build_launcher_parts(args, tdef, mbridge_repo_path, launcher_py) | ||||||||||||||||||||||||||||||||||||||
| pre_hook_sbatch_path: Optional[Path] = None | ||||||||||||||||||||||||||||||||||||||
| base_slurm_params: str = "" | ||||||||||||||||||||||||||||||||||||||
| if self.test_run.pre_test: | ||||||||||||||||||||||||||||||||||||||
| pre_hook_sbatch_path = self._gen_pre_hook_sbatch() | ||||||||||||||||||||||||||||||||||||||
| parts = self._build_launcher_parts(args, tdef, mbridge_repo_path, launcher_py, include_slurm_params=False) | ||||||||||||||||||||||||||||||||||||||
| base_slurm_params = ";".join(self._collect_additional_slurm_params()) | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| parts = self._build_launcher_parts(args, tdef, mbridge_repo_path, launcher_py) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| launcher_python = str((venv_path / "bin" / "python").absolute()) | ||||||||||||||||||||||||||||||||||||||
| full_cmd = self._wrap_launcher_for_job_id_and_quiet_output(" ".join(parts), launcher_python) | ||||||||||||||||||||||||||||||||||||||
| full_cmd = self._wrap_launcher_for_job_id_and_quiet_output( | ||||||||||||||||||||||||||||||||||||||
| " ".join(parts), | ||||||||||||||||||||||||||||||||||||||
| launcher_python, | ||||||||||||||||||||||||||||||||||||||
| pre_hook_sbatch_path=pre_hook_sbatch_path, | ||||||||||||||||||||||||||||||||||||||
| base_slurm_params=base_slurm_params, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| self._write_command_to_file(full_cmd, self.test_run.output_path) | ||||||||||||||||||||||||||||||||||||||
| return full_cmd | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _collect_additional_slurm_params(self) -> list[str]: | ||||||||||||||||||||||||||||||||||||||
| """Return the additional_slurm_params list (without dependency).""" | ||||||||||||||||||||||||||||||||||||||
| params: list[str] = [] | ||||||||||||||||||||||||||||||||||||||
| if self.system.gpus_per_node and self.system.supports_gpu_directives: | ||||||||||||||||||||||||||||||||||||||
| params.append(f"gpus-per-node={self.system.gpus_per_node}") | ||||||||||||||||||||||||||||||||||||||
| params.append(f"gres=gpu:{self.system.gpus_per_node}") | ||||||||||||||||||||||||||||||||||||||
| _, node_list = self.get_cached_nodes_spec() | ||||||||||||||||||||||||||||||||||||||
| if node_list: | ||||||||||||||||||||||||||||||||||||||
| params.append(f"nodelist={','.join(node_list)}") | ||||||||||||||||||||||||||||||||||||||
| elif self.test_run.exclude_nodes: | ||||||||||||||||||||||||||||||||||||||
| params.append(f"exclude={','.join(self.test_run.exclude_nodes)}") | ||||||||||||||||||||||||||||||||||||||
| for source in (self.system.extra_srun_args, self.test_run.extra_srun_args): | ||||||||||||||||||||||||||||||||||||||
| if source: | ||||||||||||||||||||||||||||||||||||||
| params.extend(self._parse_srun_args_as_slurm_params(source)) | ||||||||||||||||||||||||||||||||||||||
| return params | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _gen_pre_hook_sbatch(self) -> Path: | ||||||||||||||||||||||||||||||||||||||
| """Generate a standalone sbatch script for pre-hook tests; return its path.""" | ||||||||||||||||||||||||||||||||||||||
| pre_hook_output = self.test_run.output_path / "pre_hook" | ||||||||||||||||||||||||||||||||||||||
| pre_hook_output.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for tr in self.test_run.pre_test.test_runs: | ||||||||||||||||||||||||||||||||||||||
| tr.num_nodes = self.test_run.nnodes | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| pre_hook_cmds = self.gen_pre_test(self.test_run.pre_test, self.test_run.output_path) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| num_nodes, node_list = self.get_cached_nodes_spec() | ||||||||||||||||||||||||||||||||||||||
| sbatch_lines = [ | ||||||||||||||||||||||||||||||||||||||
| "#!/bin/bash", | ||||||||||||||||||||||||||||||||||||||
| "# Pre-hook sbatch generated by CloudAI", | ||||||||||||||||||||||||||||||||||||||
| f"#SBATCH --job-name=pre_hook_{self.job_name()}", | ||||||||||||||||||||||||||||||||||||||
| f"#SBATCH --output={pre_hook_output.absolute() / 'stdout.txt'}", | ||||||||||||||||||||||||||||||||||||||
| f"#SBATCH --error={pre_hook_output.absolute() / 'stderr.txt'}", | ||||||||||||||||||||||||||||||||||||||
| f"#SBATCH --partition={self.system.default_partition}", | ||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||
| if self.system.account: | ||||||||||||||||||||||||||||||||||||||
| sbatch_lines.append(f"#SBATCH --account={self.system.account}") | ||||||||||||||||||||||||||||||||||||||
| if node_list: | ||||||||||||||||||||||||||||||||||||||
| sbatch_lines.append(f"#SBATCH --nodelist={','.join(node_list)}") | ||||||||||||||||||||||||||||||||||||||
| elif num_nodes: | ||||||||||||||||||||||||||||||||||||||
| sbatch_lines.append(f"#SBATCH --nodes={num_nodes}") | ||||||||||||||||||||||||||||||||||||||
| if self.test_run.time_limit: | ||||||||||||||||||||||||||||||||||||||
| sbatch_lines.append(f"#SBATCH --time={self.test_run.time_limit}") | ||||||||||||||||||||||||||||||||||||||
| sbatch_lines.extend(["", pre_hook_cmds]) | ||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🔴 Critical | ⚡ Quick win Exit non-zero when the pre-hook success check fails.
🐛 Proposed fix- sbatch_lines.extend(["", pre_hook_cmds])
+ sbatch_lines.extend(
+ [
+ "",
+ pre_hook_cmds,
+ "",
+ 'if [ "${PRE_TEST_SUCCESS:-0}" -ne 1 ]; then',
+ ' echo "Pre-hook failed; not releasing dependent training job." >&2',
+ " exit 1",
+ "fi",
+ ]
+ )📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| sbatch_path = self.test_run.output_path / "pre_hook_sbatch_script.sh" | ||||||||||||||||||||||||||||||||||||||
| sbatch_path.write_text("\n".join(sbatch_lines)) | ||||||||||||||||||||||||||||||||||||||
| sbatch_path.chmod(sbatch_path.stat().st_mode | stat.S_IXUSR) | ||||||||||||||||||||||||||||||||||||||
| return sbatch_path | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def store_test_run(self) -> None: | ||||||||||||||||||||||||||||||||||||||
| test_cmd = self.gen_exec_command() | ||||||||||||||||||||||||||||||||||||||
| trd = TestRunDetails.from_test_run(self.test_run, test_cmd=test_cmd, full_cmd=test_cmd) | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -166,12 +229,21 @@ def _normalize_cuda_graph_scope_arg(self, val: Any) -> str: | |||||||||||||||||||||||||||||||||||||
| parts = [p.strip().strip("\"'") for p in s.split(",") if p.strip()] | ||||||||||||||||||||||||||||||||||||||
| return ",".join(parts) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher_python: str) -> str: | ||||||||||||||||||||||||||||||||||||||
| def _wrap_launcher_for_job_id_and_quiet_output( | ||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||
| launcher_cmd: str, | ||||||||||||||||||||||||||||||||||||||
| launcher_python: str, | ||||||||||||||||||||||||||||||||||||||
| pre_hook_sbatch_path: Optional[Path] = None, | ||||||||||||||||||||||||||||||||||||||
| base_slurm_params: str = "", | ||||||||||||||||||||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Run the Megatron-Bridge launcher quietly and ensure CloudAI can parse a job ID. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| CloudAI's SlurmRunner expects stdout to include "Submitted batch job <id>". | ||||||||||||||||||||||||||||||||||||||
| This writes a readable wrapper script (with section breaks) into the test output directory, then runs it. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| If pre_hook_sbatch_path is provided, the pre-hook sbatch is submitted first and its job ID is used as | ||||||||||||||||||||||||||||||||||||||
| a Slurm dependency (afterok) for the main training job, so training only starts if the pre-hook passed. | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| output_dir = self.test_run.output_path.absolute() | ||||||||||||||||||||||||||||||||||||||
| output_dir.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -181,6 +253,28 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| container_runtime_exports = self._container_runtime_env_exports() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| pre_hook_lines: list[str] = [] | ||||||||||||||||||||||||||||||||||||||
| launch_line: str | ||||||||||||||||||||||||||||||||||||||
| if pre_hook_sbatch_path is not None: | ||||||||||||||||||||||||||||||||||||||
| pre_hook_lines = [ | ||||||||||||||||||||||||||||||||||||||
| f'PRE_HOOK_SBATCH="{pre_hook_sbatch_path.absolute()}"', | ||||||||||||||||||||||||||||||||||||||
| 'PRE_HOOK_OUTPUT=$(sbatch "$PRE_HOOK_SBATCH" 2>&1)', | ||||||||||||||||||||||||||||||||||||||
| 'PRE_HOOK_JOB_ID=$(echo "$PRE_HOOK_OUTPUT" | grep -Eo "Submitted batch job [0-9]+" | grep -Eo "[0-9]+" | tail -n1 || true)', # noqa: E501 | ||||||||||||||||||||||||||||||||||||||
| 'if [ -z "$PRE_HOOK_JOB_ID" ]; then', | ||||||||||||||||||||||||||||||||||||||
| ' echo "Failed to submit pre-hook job: $PRE_HOOK_OUTPUT" >&2', | ||||||||||||||||||||||||||||||||||||||
| " exit 1", | ||||||||||||||||||||||||||||||||||||||
| "fi", | ||||||||||||||||||||||||||||||||||||||
| 'echo "Submitted pre-hook batch job $PRE_HOOK_JOB_ID"', | ||||||||||||||||||||||||||||||||||||||
| f'ADDITIONAL_SLURM_PARAMS="{base_slurm_params}"', | ||||||||||||||||||||||||||||||||||||||
| 'if [ -n "$PRE_HOOK_JOB_ID" ]; then', | ||||||||||||||||||||||||||||||||||||||
| ' ADDITIONAL_SLURM_PARAMS="${ADDITIONAL_SLURM_PARAMS};dependency=afterok:${PRE_HOOK_JOB_ID}"', | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+268
to
+270
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win Quote base Slurm params and avoid a leading empty parameter.
🐛 Proposed fix if pre_hook_sbatch_path is not None:
+ quoted_base_slurm_params = shlex.quote(base_slurm_params)
pre_hook_lines = [
f'PRE_HOOK_SBATCH="{pre_hook_sbatch_path.absolute()}"',
'PRE_HOOK_OUTPUT=$(sbatch "$PRE_HOOK_SBATCH" 2>&1)',
@@
"fi",
'echo "Submitted pre-hook batch job $PRE_HOOK_JOB_ID"',
- f'ADDITIONAL_SLURM_PARAMS="{base_slurm_params}"',
- 'if [ -n "$PRE_HOOK_JOB_ID" ]; then',
+ f"ADDITIONAL_SLURM_PARAMS={quoted_base_slurm_params}",
+ 'if [ -n "$ADDITIONAL_SLURM_PARAMS" ]; then',
' ADDITIONAL_SLURM_PARAMS="${ADDITIONAL_SLURM_PARAMS};dependency=afterok:${PRE_HOOK_JOB_ID}"',
+ "else",
+ ' ADDITIONAL_SLURM_PARAMS="dependency=afterok:${PRE_HOOK_JOB_ID}"',
"fi",
"",
]📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
| "fi", | ||||||||||||||||||||||||||||||||||||||
| "", | ||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||
| launch_line = f'{launcher_cmd} --additional_slurm_params "$ADDITIONAL_SLURM_PARAMS" >>"$LOG" 2>&1 || LAUNCH_RC=$?' # noqa: E501 | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| launch_line = f'{launcher_cmd} >>"$LOG" 2>&1 || LAUNCH_RC=$?' | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| script_lines = [ | ||||||||||||||||||||||||||||||||||||||
| "#!/usr/bin/env bash", | ||||||||||||||||||||||||||||||||||||||
| "set -o pipefail", | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -195,6 +289,7 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher | |||||||||||||||||||||||||||||||||||||
| "", | ||||||||||||||||||||||||||||||||||||||
| *container_runtime_exports, | ||||||||||||||||||||||||||||||||||||||
| "", | ||||||||||||||||||||||||||||||||||||||
| *pre_hook_lines, | ||||||||||||||||||||||||||||||||||||||
| ': >"$LOG"', | ||||||||||||||||||||||||||||||||||||||
| "WANDB_INSTALL_RC=0", | ||||||||||||||||||||||||||||||||||||||
| f'{shlex.quote(launcher_python)} -m pip install wandb numpy==1.26.4 >>"$LOG" 2>&1 || WANDB_INSTALL_RC=$?', | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -205,7 +300,7 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher | |||||||||||||||||||||||||||||||||||||
| "fi", | ||||||||||||||||||||||||||||||||||||||
| "", | ||||||||||||||||||||||||||||||||||||||
| "LAUNCH_RC=0", | ||||||||||||||||||||||||||||||||||||||
| f'{launcher_cmd} >>"$LOG" 2>&1 || LAUNCH_RC=$?', | ||||||||||||||||||||||||||||||||||||||
| launch_line, | ||||||||||||||||||||||||||||||||||||||
| "", | ||||||||||||||||||||||||||||||||||||||
| # Parse job id from Megatron-Bridge output (multiple possible formats) | ||||||||||||||||||||||||||||||||||||||
| # Patterns: "Submitted batch job 694112", "Job id: 694112", "- Job id: 694112", "Job ID: 694112" | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -247,7 +342,12 @@ def _list_or_comma_str(self, val: str | list[str] | None) -> Optional[str]: | |||||||||||||||||||||||||||||||||||||
| raise RuntimeError("Unexpected sweeps list. At this point code expects scalars only") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _build_launcher_parts( # noqa: C901 | ||||||||||||||||||||||||||||||||||||||
| self, args: MegatronBridgeCmdArgs, tdef: MegatronBridgeTestDefinition, repo_path: Path, launcher_py: Path | ||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||
| args: MegatronBridgeCmdArgs, | ||||||||||||||||||||||||||||||||||||||
| tdef: MegatronBridgeTestDefinition, | ||||||||||||||||||||||||||||||||||||||
| repo_path: Path, | ||||||||||||||||||||||||||||||||||||||
| launcher_py: Path, | ||||||||||||||||||||||||||||||||||||||
| include_slurm_params: bool = True, | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
344
to
+350
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win 🧩 Analysis chain🏁 Script executed: #!/bin/bash
ruff check --select FBT001,FBT002 src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.pyRepository: NVIDIA/cloudai Length of output: 1524 Make ♻️ Proposed fix args: MegatronBridgeCmdArgs,
tdef: MegatronBridgeTestDefinition,
repo_path: Path,
launcher_py: Path,
+ *,
include_slurm_params: bool = True,📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.15.20)[warning] 344-344: Too many branches (28 > 12) (PLR0912) [warning] 344-344: Too many statements (152 > 50) (PLR0915) [warning] 350-350: Boolean-typed positional argument in function definition (FBT001) [warning] 350-350: Boolean default positional argument in function definition (FBT002) 🤖 Prompt for AI AgentsSource: Linters/SAST tools |
||||||||||||||||||||||||||||||||||||||
| ) -> list[str]: | ||||||||||||||||||||||||||||||||||||||
| fields_set = args.model_fields_set | ||||||||||||||||||||||||||||||||||||||
| force_fields = { | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -451,25 +551,10 @@ def add_field(field: str, flag: str, value: Any) -> None: | |||||||||||||||||||||||||||||||||||||
| add_field("nsys_trace", "--nsys_trace", self._list_or_comma_str(args.nsys_trace)) | ||||||||||||||||||||||||||||||||||||||
| add_field("nsys_extra_args", "--nsys_extra_args", self._list_or_comma_str(args.nsys_extra_args)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| additional_slurm_params: list[str] = [] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if self.system.gpus_per_node and self.system.supports_gpu_directives: | ||||||||||||||||||||||||||||||||||||||
| additional_slurm_params.append(f"gpus-per-node={self.system.gpus_per_node}") | ||||||||||||||||||||||||||||||||||||||
| additional_slurm_params.append(f"gres=gpu:{self.system.gpus_per_node}") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| _, node_list = self.get_cached_nodes_spec() | ||||||||||||||||||||||||||||||||||||||
| if node_list: | ||||||||||||||||||||||||||||||||||||||
| nodelist_str = ",".join(node_list) | ||||||||||||||||||||||||||||||||||||||
| additional_slurm_params.append(f"nodelist={nodelist_str}") | ||||||||||||||||||||||||||||||||||||||
| elif self.test_run.exclude_nodes: | ||||||||||||||||||||||||||||||||||||||
| additional_slurm_params.append(f"exclude={','.join(self.test_run.exclude_nodes)}") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for source in (self.system.extra_srun_args, self.test_run.extra_srun_args): | ||||||||||||||||||||||||||||||||||||||
| if source: | ||||||||||||||||||||||||||||||||||||||
| additional_slurm_params.extend(self._parse_srun_args_as_slurm_params(source)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if additional_slurm_params: | ||||||||||||||||||||||||||||||||||||||
| parts.extend(["--additional_slurm_params", shlex.quote(";".join(additional_slurm_params))]) | ||||||||||||||||||||||||||||||||||||||
| if include_slurm_params: | ||||||||||||||||||||||||||||||||||||||
| additional_slurm_params = self._collect_additional_slurm_params() | ||||||||||||||||||||||||||||||||||||||
| if additional_slurm_params: | ||||||||||||||||||||||||||||||||||||||
| parts.extend(["--additional_slurm_params", shlex.quote(";".join(additional_slurm_params))]) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Config variant | ||||||||||||||||||||||||||||||||||||||
| add_field("config_variant", "-cv", args.config_variant) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Preserve excluded nodes in the pre-hook sbatch.
When
get_cached_nodes_spec()returns onlynum_nodes, this script emits--nodesbut dropsself.test_run.exclude_nodes, so the pre-hook can run on nodes the main job excludes.🐛 Proposed fix
if node_list: sbatch_lines.append(f"`#SBATCH` --nodelist={','.join(node_list)}") elif num_nodes: sbatch_lines.append(f"`#SBATCH` --nodes={num_nodes}") + if self.test_run.exclude_nodes: + sbatch_lines.append(f"`#SBATCH` --exclude={','.join(self.test_run.exclude_nodes)}")📝 Committable suggestion
🤖 Prompt for AI Agents
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a broader issue rather than only num nodes. it also doesn't set distribution, gres/gpu-per-node and some other potential directives. what I suggest as a fix, is to update the base class
SlurmCommandGenStrategywith extract directives generation into a sub-method. So that you could explicitly call it here to write directives into your bufferThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Length of output: 810
🏁 Script executed:
Length of output: 2675
🏁 Script executed:
#!/bin/bash ast-grep outline src/cloudai/systems/slurm/slurm_command_gen_strategy.pyLength of output: 995
🏁 Script executed:
Length of output: 1164
🏁 Script executed:
Length of output: 5433
@podkidyshev, agreed — that's the better fix. I checkedsrc/cloudai/systems/slurm/slurm_command_gen_strategy.py, and it already has most of this logic centralized in_append_sbatch_directives(Line 375) and_append_nodes_related_directives(Line 411), covering: reservation,--distribution, nodelist/hostfile vs-N,exclude,gpus-per-node/gres,ntasks-per-node,--time, andextra_sbatch_args. The_gen_pre_hook_sbatchmethod inslurm_command_gen_strategy.pycurrently reimplements only a small subset of that (job-name, output/error, partition, account, nodelist/nodes, time_limit), which is why distribution, gres/gpus-per-node, ntasks-per-node, extra_sbatch_args, and exclude are all missing there.The one wrinkle is that
_append_sbatch_directiveshardcodesself.test_run.output_path(for stdout/stderr) andself.job_name(), both of which differ for the pre-hook script (pre_hooksubdir,pre_hook_prefix). So a clean extraction would be something like pulling the node/GPU/misc-directive portion (distribution, nodelist/exclude/hostfile, gpus-per-node/gres, ntasks-per-node, extra_sbatch_args, time_limit) out of_append_sbatch_directivesinto its own method, e.g._append_resource_directives(content: List[str]), that both_write_sbatch_scriptand_gen_pre_hook_sbatchcall directly, while each caller still appends its own job-name/output/error/partition/account lines.Since this touches the shared base class and is broader than the original single-line fix, would you like this done in this PR, or should I open a follow-up issue to track the base-class refactor so both
_gen_pre_hook_sbatchand the main sbatch path stay in sync?✏️ Learnings added
🧠 Learnings used