Coverage for instanovo/common/utils.py: 67%
134 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1import math
2import os
3import time
4import traceback
5from typing import Any, Dict, List, Optional
7import neptune
8from torch.utils.tensorboard import SummaryWriter
10from instanovo.__init__ import console
11from instanovo.utils.colorlogging import ColorLog
12from instanovo.utils.data_handler import SpectrumDataFrame
# Module-level logger routed through the shared Rich console.
logger = ColorLog(console, __name__).logger

# Used to record additional training state parameters
# This is only used for accelerate.save_state and
# accelerate.load_state. (resuming runs)
class TrainingState:
    """Mutable counters describing how far a training run has progressed.

    Accelerate serializes this object via ``state_dict`` / ``load_state_dict``
    when saving and restoring checkpoints, so resumed runs continue from the
    recorded epoch and global step.
    """

    def __init__(self) -> None:
        """Start a fresh state with both counters at zero."""
        self._global_step: int = 0
        self._epoch: int = 0

    @property
    def global_step(self) -> int:
        """Current global optimization step."""
        return self._global_step

    @property
    def epoch(self) -> int:
        """Current epoch index."""
        return self._epoch

    def state_dict(self) -> dict[str, Any]:
        """Serialize the counters for checkpointing.

        Returns:
            dict[str, Any]: Dictionary containing the current training state.
        """
        return {"global_step": self._global_step, "epoch": self._epoch}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the counters from a previously saved dictionary.

        Args:
            state_dict: Dictionary containing the training state to load.
        """
        self._epoch = state_dict["epoch"]
        self._global_step = state_dict["global_step"]

    def step(self) -> None:
        """Advance the global step by one."""
        self._global_step += 1

    def step_epoch(self) -> None:
        """Advance the epoch by one."""
        self._epoch += 1

    def unstep_epoch(self) -> None:
        """Roll the epoch back by one."""
        self._epoch -= 1
76class Timer:
77 """Timer for training and validation."""
79 def __init__(self, total_steps: int | None = None):
80 self.start_time = time.time()
81 self.total_steps = total_steps
82 self.current_step = 0
84 def start(self) -> None:
85 """Restart the timer."""
86 self.start_time = time.time()
87 self.current_step = 0
89 def step(self) -> None:
90 """Step the timer."""
91 self.current_step += 1
92 self.last_time = time.time()
94 def get_delta(self) -> float:
95 """Get the time delta since the timer was started."""
96 return self.last_time - self.start_time
98 def get_eta(self, current_step: int | None = None) -> float:
99 """Get the estimated time to completion."""
100 if self.total_steps is None:
101 raise ValueError("Total steps is not set.")
102 current_step = current_step or self.current_step
103 if current_step == 0:
104 return 0
105 return self.get_delta() / current_step * max(self.total_steps - current_step, 0)
107 def get_total_time(self) -> float:
108 """Get the total time expected to complete all steps."""
109 if self.total_steps is None:
110 raise ValueError("Total steps is not set.")
111 return self.get_delta() / self.current_step * self.total_steps
113 def get_rate(self, current_step: int | None = None) -> float:
114 """Get the rate of steps per second."""
115 current_step = current_step or self.current_step
116 return current_step / self.get_delta()
118 def get_step_time(self, current_step: int | None = None) -> float:
119 """Get the time per step."""
120 current_step = current_step or self.current_step
121 return self.get_delta() / current_step
123 def get_time_str(self) -> str:
124 """Get the time delta since the timer was started."""
125 return Timer._format_time(self.get_delta())
127 def get_eta_str(self, current_step: int | None = None) -> str:
128 """Get the estimated time to completion."""
129 current_step = current_step or self.current_step
130 return Timer._format_time(self.get_eta(current_step))
132 def get_total_time_str(self) -> str:
133 """Get the total time expected to complete all steps."""
134 return Timer._format_time(self.get_total_time())
136 def get_rate_str(self, current_step: int | None = None) -> str:
137 """Get the rate of steps per second."""
138 current_step = current_step or self.current_step
139 return f"{self.get_rate(current_step):.2f} steps/s"
141 def get_step_time_rate_str(self, current_step: int | None = None) -> str:
142 """Get the time per step."""
143 current_step = current_step or self.current_step
144 return f"{self.get_step_time(current_step):.2f} s/step"
146 def get_step_time_str(self, current_step: int | None = None) -> str:
147 """Get the time per step."""
148 current_step = current_step or self.current_step
149 return Timer._format_time(self.get_step_time(current_step))
151 @staticmethod
152 def _format_time(seconds: float) -> str:
153 """Format time in seconds to HH:MM:SS."""
154 seconds = int(seconds)
155 return f"{seconds // 3600:02d}:{(seconds % 3600) // 60:02d}:{seconds % 60:02d}"
class NeptuneSummaryWriter(SummaryWriter):
    """Tensorboard SummaryWriter that mirrors every logged value to a Neptune run."""

    def __init__(self, log_dir: str, run: neptune.Run) -> None:
        super().__init__(log_dir=log_dir)
        self.run = run

    def add_scalar(self, tag: str, scalar_value: float, global_step: int | float | None = None) -> None:
        """Record scalar to tensorboard and Neptune."""
        if math.isnan(scalar_value):
            # A NaN metric signals a broken training run (e.g. exploding
            # gradients); fail loudly with the calling context attached.
            message = (
                f"NaN value detected when logging metric '{tag}' at step {global_step}. "
                f"This indicates a serious training problem (e.g., exploding gradients, division by zero). "
                f"Stopping training to prevent further issues.\n\n"
                f"Traceback showing where this NaN value originated:\n"
            )
            # Drop the frame for this method, keep the six nearest callers.
            frames = traceback.format_stack()[:-1]
            message += "".join(frames[-6:])
            raise ValueError(message)

        super().add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
        self.run[tag].append(scalar_value, step=global_step)

    def add_text(
        self,
        tag: str,
        text_string: str,
        global_step: Optional[int] = None,
        walltime: Optional[float] = None,
    ) -> None:
        """Record text to tensorboard and Neptune."""
        super().add_text(tag=tag, text_string=text_string, global_step=global_step, walltime=walltime)
        self.run[tag] = text_string

    def add_hparams(
        self,
        hparam_dict: dict,
        metric_dict: dict,
        hparam_domain_discrete: Optional[Dict[str, List[Any]]] = None,
        run_name: Optional[str] = None,
        global_step: Optional[int] = None,
    ) -> None:
        """Add a set of hyperparameters to be compared in Neptune as for Tensorboard."""
        super().add_hparams(
            hparam_dict,
            metric_dict,
            hparam_domain_discrete=hparam_domain_discrete,
            run_name=run_name,
            global_step=global_step,
        )
        # Neptune has no hparams concept; log each flattened key under "params/...".
        for keypath, setting in _flatten_dict_using_keypath(hparam_dict, base_keypath="params").items():
            self.run[keypath] = setting
def _set_author_neptune_api_token() -> None:
    """Set the variable NEPTUNE_API_TOKEN based on the email of commit author.

    It is useful on AIchor to have proper owner of each run.
    """
    email = os.environ.get("VCS_AUTHOR_EMAIL")
    if email is None:
        # VCS_AUTHOR_EMAIL is only injected on AIchor.
        logger.debug("We are not running on AIchor (https://aichor.ai/), not looking for Neptune API token.")
        return

    # Keep only the local part and normalise it to an env-var-safe name.
    local_part, _ = email.split("@")
    local_part = local_part.replace("-", "_").replace(".", "_").upper()
    env_key = f"{local_part}__NEPTUNE_API_TOKEN"

    logger.info(f"Checking for Neptune API token under {env_key}.")
    token = os.environ.get(env_key)
    if token is None:
        logger.info(f"Neptune credentials for user {local_part} not found.")
    else:
        os.environ["NEPTUNE_API_TOKEN"] = token
        logger.info(f"Set token for {local_part}.")
def _get_filepath_mapping(file_groups: Dict[str, str]) -> Dict[str, str]:
    """Get filepath mapping for validation groups.

    Expands each group's path pattern and maps every resolved file path
    back to its group name.
    """
    return {
        filepath: group_name
        for group_name, path_pattern in file_groups.items()
        for filepath in SpectrumDataFrame._convert_file_paths(path_pattern)
    }
256def _flatten_dict_using_keypath(obj: Dict[str, Any], base_keypath: str = "", sep: str = "/") -> Dict[str, Any]:
257 """Recursively flatten a nested mapping into a single-level dict with joined keys (usually called keypaths).
259 Example:
260 _flatten_dict_using_keypath({"a": {"b": 1}, "c": 2}) -> {"a/b": 1, "c": 2}
261 """
262 flatten: Dict[str, Any] = {}
264 for key, value in obj.items():
265 key_str = str(key)
266 new_key = f"{base_keypath}{sep}{key_str}" if base_keypath else key_str
268 if isinstance(value, dict):
269 deeper = _flatten_dict_using_keypath(value, base_keypath=new_key, sep=sep)
270 flatten.update(deeper)
271 else:
272 flatten[new_key] = value
274 return flatten