Coverage for instanovo/common/utils.py: 67% (134 statements)


import math
import os
import time
import traceback
from typing import Any, Dict, List, Optional

import neptune
from torch.utils.tensorboard import SummaryWriter

from instanovo.__init__ import console
from instanovo.utils.colorlogging import ColorLog
from instanovo.utils.data_handler import SpectrumDataFrame

logger = ColorLog(console, __name__).logger


# Used to record additional training state parameters.
# This is only used for accelerate.save_state and
# accelerate.load_state (when resuming runs).
class TrainingState:
    """Training state for tracking training progress.

    This class is used by Accelerate to save and load training state during
    checkpointing and resuming training runs. It tracks the current epoch
    and global step of training.
    """

    def __init__(self) -> None:
        """Initialize training state with zeroed counters."""
        self._global_step: int = 0
        self._epoch: int = 0

    @property
    def global_step(self) -> int:
        """Get the current global step."""
        return self._global_step

    @property
    def epoch(self) -> int:
        """Get the current epoch."""
        return self._epoch

    def state_dict(self) -> dict[str, Any]:
        """Get the state dictionary for saving.

        Returns:
            dict[str, Any]: Dictionary containing the current training state.
        """
        return {
            "global_step": self.global_step,
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Load state from a dictionary.

        Args:
            state_dict: Dictionary containing the training state to load.
        """
        self._global_step = state_dict["global_step"]
        self._epoch = state_dict["epoch"]

    def step(self) -> None:
        """Step the global step."""
        self._global_step += 1

    def step_epoch(self) -> None:
        """Step the epoch."""
        self._epoch += 1

    def unstep_epoch(self) -> None:
        """Unstep the epoch."""
        self._epoch -= 1

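# A minimal checkpointing sketch, assuming Hugging Face Accelerate is
# installed; the checkpoint directory below is a placeholder. Objects passed
# to register_for_checkpointing must expose state_dict/load_state_dict,
# which TrainingState provides above.
def _example_training_state_checkpointing() -> None:
    """Sketch of saving and restoring TrainingState with Accelerate (not called by this module)."""
    from accelerate import Accelerator

    accelerator = Accelerator()
    state = TrainingState()
    accelerator.register_for_checkpointing(state)

    # ... a training loop would call state.step() each batch and
    # state.step_epoch() each epoch ...
    accelerator.save_state("checkpoints/latest")  # placeholder directory

    # On resume, load_state restores epoch and global_step in place.
    accelerator.load_state("checkpoints/latest")
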

class Timer:
    """Timer for training and validation."""

    def __init__(self, total_steps: int | None = None):
        """Initialize the timer, optionally with a total step count for ETA estimates."""
        self.start_time = time.time()
        self.last_time = self.start_time  # updated on every step()
        self.total_steps = total_steps
        self.current_step = 0

    def start(self) -> None:
        """Restart the timer."""
        self.start_time = time.time()
        self.last_time = self.start_time
        self.current_step = 0

    def step(self) -> None:
        """Step the timer."""
        self.current_step += 1
        self.last_time = time.time()

    def get_delta(self) -> float:
        """Get the time delta between the timer start and the most recent step."""
        return self.last_time - self.start_time

    def get_eta(self, current_step: int | None = None) -> float:
        """Get the estimated time to completion."""
        if self.total_steps is None:
            raise ValueError("Total steps is not set.")
        current_step = current_step or self.current_step
        if current_step == 0:
            return 0.0
        return self.get_delta() / current_step * max(self.total_steps - current_step, 0)

    def get_total_time(self) -> float:
        """Get the total time expected to complete all steps."""
        if self.total_steps is None:
            raise ValueError("Total steps is not set.")
        if self.current_step == 0:
            return 0.0
        return self.get_delta() / self.current_step * self.total_steps

    def get_rate(self, current_step: int | None = None) -> float:
        """Get the rate of steps per second."""
        current_step = current_step or self.current_step
        delta = self.get_delta()
        if delta == 0:
            return 0.0
        return current_step / delta

    def get_step_time(self, current_step: int | None = None) -> float:
        """Get the time per step."""
        current_step = current_step or self.current_step
        if current_step == 0:
            return 0.0
        return self.get_delta() / current_step

    def get_time_str(self) -> str:
        """Get the time delta since the timer was started, formatted as HH:MM:SS."""
        return Timer._format_time(self.get_delta())

    def get_eta_str(self, current_step: int | None = None) -> str:
        """Get the estimated time to completion, formatted as HH:MM:SS."""
        current_step = current_step or self.current_step
        return Timer._format_time(self.get_eta(current_step))

    def get_total_time_str(self) -> str:
        """Get the total time expected to complete all steps, formatted as HH:MM:SS."""
        return Timer._format_time(self.get_total_time())

    def get_rate_str(self, current_step: int | None = None) -> str:
        """Get the rate of steps per second as a formatted string."""
        current_step = current_step or self.current_step
        return f"{self.get_rate(current_step):.2f} steps/s"

    def get_step_time_rate_str(self, current_step: int | None = None) -> str:
        """Get the time per step as a formatted rate string."""
        current_step = current_step or self.current_step
        return f"{self.get_step_time(current_step):.2f} s/step"

    def get_step_time_str(self, current_step: int | None = None) -> str:
        """Get the time per step, formatted as HH:MM:SS."""
        current_step = current_step or self.current_step
        return Timer._format_time(self.get_step_time(current_step))

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format time in seconds to HH:MM:SS."""
        seconds = int(seconds)
        return f"{seconds // 3600:02d}:{(seconds % 3600) // 60:02d}:{seconds % 60:02d}"

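# A minimal progress-reporting sketch for Timer; the step count and logging
# cadence are placeholders, and the training work itself is elided.
def _example_timer_progress(total_steps: int = 100) -> None:
    """Sketch of Timer-based progress logging (not called by this module)."""
    timer = Timer(total_steps=total_steps)
    timer.start()
    for step in range(1, total_steps + 1):
        # ... one training or validation step would run here ...
        timer.step()
        if step % 10 == 0:
            logger.info(
                f"step {step}/{total_steps} | elapsed {timer.get_time_str()} | "
                f"eta {timer.get_eta_str()} | {timer.get_rate_str()}"
            )
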

class NeptuneSummaryWriter(SummaryWriter):
    """A TensorBoard SummaryWriter that mirrors logged values to a Neptune run."""

    def __init__(self, log_dir: str, run: neptune.Run) -> None:
        """Initialize the writer with a TensorBoard log directory and a Neptune run."""
        super().__init__(log_dir=log_dir)
        self.run = run

    def add_scalar(self, tag: str, scalar_value: float, global_step: int | float | None = None) -> None:
        """Record scalar to TensorBoard and Neptune."""
        # Check for NaN values - these indicate a serious training problem
        if math.isnan(scalar_value):
            error_msg = (
                f"NaN value detected when logging metric '{tag}' at step {global_step}. "
                f"This indicates a serious training problem (e.g., exploding gradients, division by zero). "
                f"Stopping training to prevent further issues.\n\n"
                f"Traceback showing where this NaN value originated:\n"
            )
            # Get the current stack trace
            stack_trace = traceback.format_stack()
            # Drop the last frame (this method) and keep the six frames
            # preceding it, which show where add_scalar was called from.
            relevant_frames = stack_trace[:-1][-6:]
            error_msg += "".join(relevant_frames)
            raise ValueError(error_msg)

        super().add_scalar(
            tag=tag,
            scalar_value=scalar_value,
            global_step=global_step,
        )
        self.run[tag].append(scalar_value, step=global_step)

    def add_text(
        self,
        tag: str,
        text_string: str,
        global_step: Optional[int] = None,
        walltime: Optional[float] = None,
    ) -> None:
        """Record text to TensorBoard and Neptune."""
        super().add_text(tag=tag, text_string=text_string, global_step=global_step, walltime=walltime)

        self.run[tag] = text_string

    def add_hparams(
        self,
        hparam_dict: dict,
        metric_dict: dict,
        hparam_domain_discrete: Optional[Dict[str, List[Any]]] = None,
        run_name: Optional[str] = None,
        global_step: Optional[int] = None,
    ) -> None:
        """Add a set of hyperparameters to be compared, in Neptune as well as TensorBoard."""
        super().add_hparams(
            hparam_dict,
            metric_dict,
            hparam_domain_discrete=hparam_domain_discrete,
            run_name=run_name,
            global_step=global_step,
        )
        flatten_hparam = _flatten_dict_using_keypath(hparam_dict, base_keypath="params")
        for hparam, value in flatten_hparam.items():
            self.run[hparam] = value

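# A minimal wiring sketch for NeptuneSummaryWriter; the project name and log
# directory are placeholders, and neptune.init_run is assumed to pick up
# NEPTUNE_API_TOKEN from the environment (see _set_author_neptune_api_token
# below).
def _example_neptune_summary_writer() -> None:
    """Sketch of combined TensorBoard + Neptune logging (not called by this module)."""
    run = neptune.init_run(project="my-workspace/my-project")  # placeholder project
    writer = NeptuneSummaryWriter(log_dir="runs/example", run=run)

    writer.add_scalar("train/loss", 0.123, global_step=1)
    writer.add_text("notes", "example run")

    writer.close()
    run.stop()
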

def _set_author_neptune_api_token() -> None:
    """Set the NEPTUNE_API_TOKEN environment variable based on the commit author's email.

    On AIchor, this ensures each run is attributed to its proper owner.
    """
    try:
        author_email = os.environ["VCS_AUTHOR_EMAIL"]
    except KeyError:
        # VCS_AUTHOR_EMAIL is only set on AIchor.
        logger.debug("We are not running on AIchor (https://aichor.ai/), not looking for Neptune API token.")
        return

    author_username, _ = author_email.split("@")
    author_username = author_username.replace("-", "_").replace(".", "_").upper()

    logger.info(f"Checking for Neptune API token under {author_username}__NEPTUNE_API_TOKEN.")
    try:
        author_api_token = os.environ[f"{author_username}__NEPTUNE_API_TOKEN"]
        os.environ["NEPTUNE_API_TOKEN"] = author_api_token
        logger.info(f"Set token for {author_username}.")
    except KeyError:
        logger.info(f"Neptune credentials for user {author_username} not found.")

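# An illustration of the environment-variable mapping above; the email and
# token values are fabricated placeholders.
def _example_author_token_lookup() -> None:
    """Sketch of the AIchor token lookup (not called by this module)."""
    os.environ["VCS_AUTHOR_EMAIL"] = "jane.doe@example.com"  # placeholder
    os.environ["JANE_DOE__NEPTUNE_API_TOKEN"] = "<placeholder-token>"

    # "jane.doe" becomes "JANE_DOE", so JANE_DOE__NEPTUNE_API_TOKEN is found
    # and copied into NEPTUNE_API_TOKEN.
    _set_author_neptune_api_token()
    assert os.environ["NEPTUNE_API_TOKEN"] == "<placeholder-token>"
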

def _get_filepath_mapping(file_groups: Dict[str, str]) -> Dict[str, str]:
    """Map each resolved file path to the name of its validation group."""
    group_mapping = {}
    for group, path in file_groups.items():
        for fp in SpectrumDataFrame._convert_file_paths(path):
            group_mapping[fp] = group
    return group_mapping

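# A hedged sketch of _get_filepath_mapping; the group names and glob patterns
# are placeholders, and path expansion is delegated to
# SpectrumDataFrame._convert_file_paths, so the exact keys depend on the
# files present on disk:
#
#     mapping = _get_filepath_mapping(
#         {
#             "valid_species": "data/valid_species/*.parquet",
#             "valid_ptm": "data/valid_ptm/*.parquet",
#         }
#     )
#     # e.g. {"data/valid_species/part-0.parquet": "valid_species", ...}
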

def _flatten_dict_using_keypath(obj: Dict[str, Any], base_keypath: str = "", sep: str = "/") -> Dict[str, Any]:
    """Recursively flatten a nested mapping into a single-level dict with joined keys (usually called keypaths).

    Example:
        _flatten_dict_using_keypath({"a": {"b": 1}, "c": 2}) -> {"a/b": 1, "c": 2}
    """
    flatten: Dict[str, Any] = {}

    for key, value in obj.items():
        key_str = str(key)
        new_key = f"{base_keypath}{sep}{key_str}" if base_keypath else key_str

        if isinstance(value, dict):
            deeper = _flatten_dict_using_keypath(value, base_keypath=new_key, sep=sep)
            flatten.update(deeper)
        else:
            flatten[new_key] = value

    return flatten
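

# A small worked example of the keypath flattening that add_hparams above
# relies on; the hyperparameter names are placeholders.
def _example_flatten_hparams() -> None:
    """Sketch of _flatten_dict_using_keypath as used by add_hparams (not called by this module)."""
    hparams = {"optim": {"lr": 1e-3, "betas": [0.9, 0.999]}, "epochs": 10}
    flat = _flatten_dict_using_keypath(hparams, base_keypath="params")
    # Nested keys are joined with "/" and prefixed with the base keypath.
    assert flat == {
        "params/optim/lr": 1e-3,
        "params/optim/betas": [0.9, 0.999],
        "params/epochs": 10,
    }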