Coverage for instanovo/utils/device_handler.py: 98%

57 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from typing import Any, Dict, Optional 

2 

3import torch 

4from omegaconf import DictConfig 

5 

6from instanovo.__init__ import console 

7from instanovo.utils.colorlogging import ColorLog 

8 

# Module-level logger: a ColorLog wrapper around the ``console`` object exported by
# the ``instanovo`` package, named after this module.
logger = ColorLog(console, __name__).logger

10 

11 

def get_device_capabilities() -> Dict[str, bool]:
    """Report which accelerator backends PyTorch can use right now.

    Returns:
        Mapping with exactly the keys ``"cuda"`` and ``"mps"``; each value is
        whatever the corresponding ``torch`` availability check returns.
    """
    cuda_ok = torch.cuda.is_available()
    mps_ok = torch.backends.mps.is_available()
    return {"cuda": cuda_ok, "mps": mps_ok}

22 

23 

def detect_device() -> str:
    """Pick the preferred compute device for this machine.

    Preference order is CUDA first, then MPS, falling back to CPU when
    neither accelerator is reported as available.

    Returns:
        str: One of ``'cuda'``, ``'mps'`` or ``'cpu'``.
    """
    available = get_device_capabilities()
    for candidate in ("cuda", "mps"):
        if available[candidate]:
            return candidate
    return "cpu"

38 

39 

def get_device_config_updates(device: str) -> Dict[str, Any]:
    """Build the configuration overrides that correspond to a device name.

    A fresh dictionary is returned on every call, so callers may mutate it.

    Args:
        device: Device name; must be exactly ``'cpu'``, ``'cuda'`` or ``'mps'``.

    Returns:
        Dict of top-level config keys to set. For ``'mps'`` it additionally
        contains a nested ``"model"`` dict of settings for the model section.

    Raises:
        ValueError: If ``device`` is not one of the supported names.
    """
    if device == "cuda":
        return {"mps": False, "force_cpu": False}
    if device == "mps":
        return {
            "mps": True,
            "force_fp32": True,
            "force_cpu": False,
            # Nested settings for the ``model`` section: keep peak embeddings in float32.
            "model": {"peak_embedding_dtype": "float32"},
        }
    if device == "cpu":
        return {"force_cpu": True}
    raise ValueError(f"Unknown device: {device}, no configuration updates applied.")

73 

74 

def apply_device_config(config: DictConfig, device: Optional[str] = None) -> str:
    """Write the settings for a device into ``config`` in place.

    Args:
        config: The configuration object to mutate.
        device: Device name ('cpu', 'cuda', 'mps'). When ``None``, the device
            is chosen by ``detect_device()``.

    Returns:
        str: The device whose settings were written.

    Raises:
        ValueError: Propagated from ``get_device_config_updates`` when the
            device name is not recognised.
    """
    chosen = detect_device() if device is None else device
    updates = get_device_config_updates(chosen)

    for name, setting in updates.items():
        is_nested_model = name == "model" and isinstance(setting, dict)
        if not is_nested_model:
            config[name] = setting
            continue
        # Nested model settings are merged key by key into the existing
        # ``model`` section; they are skipped when the config has no such section.
        if "model" not in config:
            continue
        for sub_key, sub_value in setting.items():
            config["model"][sub_key] = sub_value

    return chosen

99 

100 

def validate_and_configure_device(config: DictConfig) -> None:
    """Validate the requested device flags and correct them in place.

    Reconciles the ``mps``, ``force_cpu`` and ``force_fp32`` entries of
    ``config`` with what ``get_device_capabilities()`` reports:

    * Warns when MPS is available but ``mps`` is not requested.
    * If ``mps`` is requested but MPS is unavailable, or ``force_cpu`` is set,
      ``mps`` is reset to ``False``.
    * If ``mps`` is requested and usable, ``force_fp32`` is set to ``True``.
    * Otherwise, if CUDA is unavailable and ``force_cpu`` is not already set,
      ``force_cpu`` is set to ``True``. This check also runs after an MPS
      request has been rejected. As a result, a machine with neither MPS nor
      CUDA ends up with ``force_cpu=True`` even when ``mps`` was requested.

    Args:
        config: The configuration object to validate and update in place.
    """
    capabilities = get_device_capabilities()
    wants_mps = config.get("mps", False)

    if capabilities["mps"] and not wants_mps:
        logger.warning(
            "The Metal Performance Shaders (MPS) backend for Apple silicon devices is available, but not requested. "
            "See https://developer.apple.com/documentation/metalperformanceshaders for more information. "
            "Please set 'mps' to True in the configuration if you would like to use MPS."
        )

    if wants_mps:
        if not capabilities["mps"]:
            logger.warning("MPS is not available, setting mps to False.")
            config["mps"] = False
        elif config.get("force_cpu", False):
            logger.warning("Force CPU is set to True, setting mps to False.")
            config["mps"] = False
        else:
            logger.info("MPS is set to True, forcing fp32. Note that performance on MPS may differ to performance on CUDA.")
            config["force_fp32"] = True  # Force fp32 if using mps
            return
        # The MPS request was rejected. Fall through so the CUDA/CPU fallback
        # below still applies; previously it was skipped by an ``elif``.

    if not config.get("force_cpu", False) and not capabilities["cuda"]:
        logger.warning("CUDA is not available, setting force_cpu to True.")
        config["force_cpu"] = True

130 

131 

def check_device(config: Optional[DictConfig] = None) -> str:
    """Backward-compatible entry point for device selection.

    Args:
        config: When provided, device settings are written into it via
            ``apply_device_config``; when ``None``, only detection is performed.

    Returns:
        str: The selected device.
    """
    if config is None:
        return detect_device()
    return apply_device_config(config)