Coverage for instanovo/common/scheduler.py: 34% (113 statements)
import fnmatch
from typing import Any

import numpy as np
import torch
from omegaconf import DictConfig

from instanovo.__init__ import console
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger
class FinetuneScheduler:
    """Scheduler for unfreezing parameters of a model.

    Args:
        model_state_dict (dict): The state dictionary of the model.
        config (DictConfig): The configuration for the scheduler.
        steps_per_epoch (int | None): The number of steps per epoch.
    """

    def __init__(self, model_state_dict: dict, config: DictConfig, steps_per_epoch: int | None = None):
        self.model_state_dict = model_state_dict
        self.config = config
        self.steps_per_epoch = steps_per_epoch

        self.is_verbose = self.config.get("verbose", False)

        self.schedule = self._get_schedule()

        if self.is_verbose:
            logger.info(f"Unfreezing schedule setup with {len(self.schedule)} phases.")
            for i, phase in enumerate(self.schedule):
                logger.info(f" - Phase {i + 1}, global_step {phase['global_step']:,d}, params {phase['params']}")

        self._freeze_parameters()
        self.next_phase: dict[str, Any] | None = self.schedule.pop(0)
        self.step(0)  # Trigger first unfreeze
    def _get_schedule(self) -> list[dict]:
        unfreeze_format = self.config.get("unfreeze_format", "start_epoch")

        phases = self.config.get("unfreeze_schedule", [])

        if len(phases) == 0:
            raise ValueError("No unfreeze_schedule phases specified")

        if any(phase.get(unfreeze_format, None) is None for phase in phases):
            raise ValueError(f"{unfreeze_format} must be specified for each phase")

        schedule = []

        next_global_step = 0
        for phase in phases:
            global_step = 0
            match unfreeze_format:
                case "duration_epochs":
                    if self.steps_per_epoch is None:
                        raise ValueError("steps_per_epoch must be specified for epoch-based scheduling")
                    global_step = next_global_step
                    next_global_step += phase["duration_epochs"] * self.steps_per_epoch
                case "duration_steps":
                    global_step = next_global_step
                    next_global_step += phase["duration_steps"]
                case "start_epoch":
                    if self.steps_per_epoch is None:
                        raise ValueError("steps_per_epoch must be specified for epoch-based scheduling")
                    global_step = phase["start_epoch"] * self.steps_per_epoch
                case "start_step":
                    global_step = phase["start_step"]
                case _:
                    raise ValueError(f"Invalid unfreeze format: {unfreeze_format}")

            schedule.append({"global_step": global_step, "completed": False, "params": phase["params"]})

        # Check schedule is valid
        if schedule[0]["global_step"] != 0:
            raise ValueError("First phase must start at global step 0")

        last_step = schedule[0]["global_step"]
        for phase in schedule:
            step = phase["global_step"]
            if step < last_step:
                raise ValueError("Phases must be in increasing order of steps/epochs")
            last_step = step

        return schedule
    def _freeze_parameters(self) -> None:
        logger.info("Freezing model parameters")
        num_params = 0
        num_layers = 0
        for _, param in self.model_state_dict.items():
            param.requires_grad = False
            num_params += param.numel()
            num_layers += 1
        logger.info(f"Frozen {num_params:,d} parameters in {num_layers:,d} layers")
    def _unfreeze(self, param_patterns: list[str]) -> None:
        logger.info(f"Unfreezing parameters: {param_patterns}")
        num_params = 0
        num_layers = 0
        for name, param in self.model_state_dict.items():
            for pattern in param_patterns:
                if pattern == "*" or fnmatch.fnmatch(name, pattern):
                    param.requires_grad = True
                    num_params += param.numel()
                    num_layers += 1
                    break
        logger.info(f"Unfrozen {num_params:,d} parameters in {num_layers:,d} layers")
    def step(self, global_step: int) -> None:
        """Step the unfreezing scheduler.

        Args:
            global_step (int): The global step of the model.
        """
        if self.next_phase is None or global_step < self.next_phase["global_step"]:
            return
        self._unfreeze(self.next_phase["params"])
        if len(self.schedule) > 0:
            self.next_phase = self.schedule.pop(0)
        else:
            self.next_phase = None
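A minimal usage sketch follows (illustrative only: the model, pattern strings, and step counts are assumptions, not taken from this module). The scheduler is driven by a config with unfreeze_format and unfreeze_schedule; with the default "start_epoch" format each phase gives an absolute epoch, whereas the "duration_*" formats accumulate phase lengths. Everything is frozen on construction, and pattern-matched parameters are unfrozen as step() is called with the trainer's global step.

from omegaconf import OmegaConf

class TinyModel(torch.nn.Module):  # stand-in model for the sketch
    def __init__(self) -> None:
        super().__init__()
        self.decoder = torch.nn.Linear(8, 8)
        self.head = torch.nn.Linear(8, 2)

model = TinyModel()

# Hypothetical three-phase schedule: head first, then decoder, then all parameters.
config = OmegaConf.create(
    {
        "unfreeze_format": "start_epoch",
        "unfreeze_schedule": [
            {"start_epoch": 0, "params": ["head.*"]},
            {"start_epoch": 2, "params": ["decoder.*"]},
            {"start_epoch": 4, "params": ["*"]},
        ],
    }
)

# named_parameters() gives the name-to-parameter mapping that the
# fnmatch-based freeze/unfreeze patterns operate on.
scheduler = FinetuneScheduler(
    model_state_dict=dict(model.named_parameters()),
    config=config,
    steps_per_epoch=1000,
)

for global_step in range(5000):  # placeholder training loop
    scheduler.step(global_step)  # unfreezes the next phase once its global step is reached
    # ... forward / backward / optimizer.step() would go here ...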
class WarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup scheduler."""

    def __init__(self, optimizer: torch.optim.Optimizer, warmup: int) -> None:
        self.warmup = warmup
        super().__init__(optimizer)

    def get_lr(self) -> list[float]:
        """Get the learning rate at the current step."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch: int) -> float:
        """Get the LR factor at the current step."""
        lr_factor = 1.0
        if epoch <= self.warmup:
            lr_factor *= epoch / self.warmup
        return lr_factor
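A short sketch of attaching this scheduler to an optimizer (the model, learning rate, and warmup length are arbitrary placeholders). The factor rises linearly from 0 to 1 over the first warmup calls to step() and stays at 1 afterwards, so step() would typically be called once per optimizer update.

model = torch.nn.Linear(16, 4)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = WarmupScheduler(optimizer, warmup=100)

for step in range(500):
    # ... forward / backward would go here ...
    optimizer.step()
    scheduler.step()  # LR = base_lr * step / warmup during warmup, base_lr afterwards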
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Learning rate scheduler with linear warm up followed by cosine shaped decay.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer object.
    warmup : int
        The number of warm up iterations.
    max_iters : int
        The total number of iterations.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, warmup: int, max_iters: int):
        self.warmup, self.max_iters = warmup, max_iters
        super().__init__(optimizer)

    def get_lr(self) -> list[float]:
        """Get the learning rate at the current step."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch: int) -> float:
        """Get the LR factor at the current step."""
        # Cosine annealing after the linear warmup period
        decay = self.warmup / self.max_iters
        if epoch <= self.warmup and self.warmup > 0:
            lr_factor = epoch / self.warmup
        else:
            lr_factor = 0.5 * (1 + np.cos(np.pi * ((epoch - (decay * self.max_iters)) / ((1 - decay) * self.max_iters))))

        return lr_factor
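For reference, since decay = warmup / max_iters, the decay branch above simplifies to 0.5 * (1 + cos(pi * (epoch - warmup) / (max_iters - warmup))): a half cosine that falls from 1 to 0 over the post-warmup iterations. A small illustrative check, with arbitrary warmup and max_iters values:

model = torch.nn.Linear(16, 4)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineWarmupScheduler(optimizer, warmup=10, max_iters=100)

# Approximate factors expected from the simplified formula:
#   epoch 0   -> 0.0  (start of linear warmup)
#   epoch 10  -> 1.0  (warmup complete, decay begins)
#   epoch 55  -> 0.5  (halfway through the decay window)
#   epoch 100 -> 0.0  (end of the schedule)
for epoch in (0, 10, 55, 100):
    print(epoch, scheduler.get_lr_factor(epoch=epoch))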