Coverage for instanovo/common/scheduler.py: 34%

113 statements  

coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

import fnmatch
from typing import Any

import numpy as np
import torch
from omegaconf import DictConfig

from instanovo.__init__ import console
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger


class FinetuneScheduler:
    """Scheduler for unfreezing parameters of a model.

    Args:
        model_state_dict (dict): The state dictionary of the model.
        config (DictConfig): The configuration for the scheduler.
        steps_per_epoch (int | None): The number of steps per epoch.
    """

    def __init__(self, model_state_dict: dict, config: DictConfig, steps_per_epoch: int | None = None):
        self.model_state_dict = model_state_dict
        self.config = config
        self.steps_per_epoch = steps_per_epoch

        self.is_verbose = self.config.get("verbose", False)

        self.schedule = self._get_schedule()

        if self.is_verbose:
            logger.info(f"Unfreezing schedule setup with {len(self.schedule)} phases.")
            for i, phase in enumerate(self.schedule):
                logger.info(f" - Phase {i + 1}, global_step {phase['global_step']:,d}, params {phase['params']}")

        self._freeze_parameters()
        self.next_phase: dict[str, Any] | None = self.schedule.pop(0)
        self.step(0)  # Trigger first unfreeze

    def _get_schedule(self) -> list[dict]:
        """Build the ordered list of unfreeze phases from the config."""
        unfreeze_format = self.config.get("unfreeze_format", "start_epoch")

        phases = self.config.get("unfreeze_schedule", [])

        if len(phases) == 0:
            raise ValueError("No unfreeze_schedule phases specified")

        if any(phase.get(unfreeze_format, None) is None for phase in phases):
            raise ValueError(f"{unfreeze_format} must be specified for each phase")

        schedule = []

        next_global_step = 0
        for phase in phases:
            global_step = 0
            match unfreeze_format:
                case "duration_epochs":
                    if self.steps_per_epoch is None:
                        raise ValueError("steps_per_epoch must be specified for epoch-based scheduling")
                    global_step = next_global_step
                    next_global_step += phase["duration_epochs"] * self.steps_per_epoch
                case "duration_steps":
                    global_step = next_global_step
                    next_global_step += phase["duration_steps"]
                case "start_epoch":
                    if self.steps_per_epoch is None:
                        raise ValueError("steps_per_epoch must be specified for epoch-based scheduling")
                    global_step = phase["start_epoch"] * self.steps_per_epoch
                case "start_step":
                    global_step = phase["start_step"]
                case _:
                    raise ValueError(f"Invalid unfreeze format: {unfreeze_format}")

            schedule.append({"global_step": global_step, "completed": False, "params": phase["params"]})

        # Check schedule is valid
        if schedule[0]["global_step"] != 0:
            raise ValueError("First phase must start at global step 0")

        last_step = schedule[0]["global_step"]
        for phase in schedule:
            step = phase["global_step"]
            if step < last_step:
                raise ValueError("Phases must be in increasing order of steps/epochs")
            last_step = step

        return schedule

    def _freeze_parameters(self) -> None:
        """Freeze all parameters in the model state dict."""
        logger.info("Freezing model parameters")
        num_params = 0
        num_layers = 0
        for _, param in self.model_state_dict.items():
            param.requires_grad = False
            num_params += param.numel()
            num_layers += 1
        logger.info(f"Frozen {num_params:,d} parameters in {num_layers:,d} layers")

    def _unfreeze(self, param_patterns: list[str]) -> None:
        """Unfreeze parameters whose names match any of the given patterns."""
        logger.info(f"Unfreezing parameters: {param_patterns}")
        num_params = 0
        num_layers = 0
        for name, param in self.model_state_dict.items():
            for pattern in param_patterns:
                if pattern == "*" or fnmatch.fnmatch(name, pattern):
                    param.requires_grad = True
                    num_params += param.numel()
                    num_layers += 1
                    break
        logger.info(f"Unfrozen {num_params:,d} parameters in {num_layers:,d} layers")

    def step(self, global_step: int) -> None:
        """Step the unfreezing scheduler.

        Args:
            global_step (int): The global step of the model.
        """
        if self.next_phase is None or global_step < self.next_phase["global_step"]:
            return
        self._unfreeze(self.next_phase["params"])
        if len(self.schedule) > 0:
            self.next_phase = self.schedule.pop(0)
        else:
            self.next_phase = None

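For orientation, here is a minimal sketch of how FinetuneScheduler could be driven from a training loop. The toy model, parameter patterns, and phase boundaries are illustrative assumptions, not taken from the InstaNovo training code; only the config keys (verbose, unfreeze_format, unfreeze_schedule, params) mirror what _get_schedule reads above.

    import torch.nn as nn
    from omegaconf import OmegaConf

    from instanovo.common.scheduler import FinetuneScheduler

    # Hypothetical two-layer model and config; only the keys read by
    # FinetuneScheduler are shown.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    cfg = OmegaConf.create(
        {
            "verbose": True,
            "unfreeze_format": "start_epoch",
            "unfreeze_schedule": [
                {"start_epoch": 0, "params": ["1.*"]},  # only the last layer from the start
                {"start_epoch": 2, "params": ["*"]},    # everything from epoch 2
            ],
        }
    )

    # The scheduler toggles requires_grad on the tensors it is given, so this
    # sketch passes the live named parameters rather than a detached copy.
    scheduler = FinetuneScheduler(dict(model.named_parameters()), cfg, steps_per_epoch=100)

    for global_step in range(300):
        scheduler.step(global_step)
        # ... forward, loss.backward(), optimizer.step() would go here ...

Passing dict(model.named_parameters()) matters in this sketch because model.state_dict() returns detached tensors by default, so flipping requires_grad on them would not affect training.
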

class WarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup scheduler."""

    def __init__(self, optimizer: torch.optim.Optimizer, warmup: int) -> None:
        self.warmup = warmup
        super().__init__(optimizer)

    def get_lr(self) -> list[float]:
        """Get the learning rate at the current step."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch: int) -> float:
        """Get the LR factor at the current step."""
        lr_factor = 1.0
        if epoch <= self.warmup:
            lr_factor *= epoch / self.warmup
        return lr_factor

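A similar sketch for WarmupScheduler, stepped once per optimizer step; the toy model, base learning rate, and warmup length are arbitrary choices for illustration. Since get_lr_factor divides by self.warmup, the example assumes warmup > 0.

    import torch

    from instanovo.common.scheduler import WarmupScheduler

    model = torch.nn.Linear(4, 1)  # toy model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = WarmupScheduler(optimizer, warmup=1000)

    for step in range(5000):
        # ... forward, loss.backward(), optimizer.step() ...
        optimizer.zero_grad()
        scheduler.step()  # LR ramps linearly to 1e-3 over the first 1000 steps, then stays flat
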

class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Learning rate scheduler with linear warm up followed by cosine shaped decay.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer object.
    warmup : int
        The number of warm up iterations.
    max_iters : int
        The total number of iterations.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, warmup: int, max_iters: int):
        self.warmup, self.max_iters = warmup, max_iters
        super().__init__(optimizer)

    def get_lr(self) -> list[float]:
        """Get the learning rate at the current step."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch: int) -> float:
        """Get the LR factor at the current step."""
        # Cosine annealing after the linear warmup period
        decay = self.warmup / self.max_iters
        if epoch <= self.warmup and self.warmup > 0:
            lr_factor = 1 * (epoch / self.warmup)
        else:
            lr_factor = 0.5 * (1 + np.cos(np.pi * ((epoch - (decay * self.max_iters)) / ((1 - decay) * self.max_iters))))

        return lr_factor
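
And a matching sketch for CosineWarmupScheduler, which ramps the learning rate linearly over warmup iterations and then decays it along a half cosine, reaching zero at max_iters. The optimizer and step counts below are illustrative only.

    import torch

    from instanovo.common.scheduler import CosineWarmupScheduler

    model = torch.nn.Linear(4, 1)  # toy model
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    scheduler = CosineWarmupScheduler(optimizer, warmup=1000, max_iters=100_000)

    for step in range(100_000):
        # ... forward, loss.backward(), optimizer.step() ...
        optimizer.zero_grad()
        scheduler.step()  # linear ramp to 5e-4 over 1k steps, then cosine decay to 0 by 100k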