Coverage for instanovo/utils/device_handler.py: 98%
57 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1from typing import Any, Dict, Optional
3import torch
4from omegaconf import DictConfig
6from instanovo.__init__ import console
7from instanovo.utils.colorlogging import ColorLog
# Module-level logger for this file, obtained from a ColorLog built on the
# package-wide `console` object imported from `instanovo`.
logger = ColorLog(console, __name__).logger
def get_device_capabilities() -> Dict[str, bool]:
    """Report which accelerator backends PyTorch can use on this machine.

    Returns:
        Dict with boolean flags under the keys ``"cuda"`` and ``"mps"``, taken
        from ``torch.cuda.is_available()`` and
        ``torch.backends.mps.is_available()`` respectively.
    """
    cuda_available = torch.cuda.is_available()
    mps_available = torch.backends.mps.is_available()
    return {"cuda": cuda_available, "mps": mps_available}
def detect_device() -> str:
    """Pick the preferred compute device available on this machine.

    Backends are tried in preference order: CUDA first, then MPS, and CPU is
    used when neither accelerator is available.

    Returns:
        str: One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    available = get_device_capabilities()
    for candidate in ("cuda", "mps"):
        if available[candidate]:
            return candidate
    return "cpu"
def get_device_config_updates(device: str) -> Dict[str, Any]:
    """Build the configuration overrides that correspond to a device type.

    Args:
        device: The device type; must be ``'cpu'``, ``'cuda'`` or ``'mps'``.

    Returns:
        A fresh dict of top-level config keys to set. For ``'mps'`` it also
        carries a nested ``"model"`` dict with model-section overrides.

    Raises:
        ValueError: If ``device`` is not one of the supported device types.
    """
    if device == "cuda":
        return {"mps": False, "force_cpu": False}
    if device == "mps":
        # MPS runs are forced to fp32, including the peak embedding dtype.
        return {
            "mps": True,
            "force_fp32": True,
            "force_cpu": False,
            "model": {"peak_embedding_dtype": "float32"},
        }
    if device == "cpu":
        return {"force_cpu": True}
    raise ValueError(f"Unknown device: {device}, no configuration updates applied.")
def apply_device_config(config: DictConfig, device: Optional[str] = None) -> str:
    """Write the settings for a device into ``config`` in place.

    Args:
        config: The configuration object to update.
        device: Device type to configure for (``'cpu'``, ``'cuda'`` or
            ``'mps'``). When None, the device is auto-detected.

    Returns:
        str: The device whose settings were applied.
    """
    chosen = detect_device() if device is None else device

    for key, value in get_device_config_updates(chosen).items():
        if not (key == "model" and isinstance(value, dict)):
            config[key] = value
            continue
        # Nested model overrides are merged key-by-key into an existing
        # "model" section; they are skipped when config has no such section.
        if "model" not in config:
            continue
        for model_key, model_value in value.items():
            config["model"][model_key] = model_value

    return chosen
def validate_and_configure_device(config: DictConfig) -> None:
    """Reconcile the requested device flags with what this machine supports.

    Reads and may rewrite the ``mps``, ``force_cpu`` and ``force_fp32`` keys of
    ``config`` in place:

    * If MPS is available but not requested, a hint is logged.
    * If MPS is requested but unavailable, or ``force_cpu`` is set, ``mps`` is
      reset to False. Otherwise MPS stays enabled and ``force_fp32`` is set.
    * If MPS is not in use, ``force_cpu`` is unset and CUDA is unavailable,
      ``force_cpu`` is set to True. This also covers the case where MPS was
      requested but just disabled above.

    Args:
        config: The configuration object to validate and update in place.
    """
    capabilities = get_device_capabilities()

    if capabilities["mps"] and not config.get("mps", False):
        logger.warning(
            "The Metal Performance Shaders (MPS) backend for Apple silicon devices is available, but not requested. "
            "See https://developer.apple.com/documentation/metalperformanceshaders for more information. "
            "Please set 'mps' to True in the configuration if you would like to use MPS."
        )

    if config.get("mps", False):
        if not capabilities["mps"]:
            logger.warning("MPS is not available, setting mps to False.")
            config["mps"] = False
        elif config.get("force_cpu", False):
            logger.warning("Force CPU is set to True, setting mps to False.")
            config["mps"] = False
        else:
            logger.info("MPS is set to True, forcing fp32. Note that performance on MPS may differ to performance on CUDA.")
            config["force_fp32"] = True  # Force fp32 if using mps
            # MPS will be used, so no CUDA/CPU fallback is required.
            return

    # Reached when MPS was never requested OR was just disabled above. The
    # fallback must run in both cases; otherwise a machine with neither MPS
    # nor CUDA would be left with force_cpu unset.
    if not config.get("force_cpu", False) and not capabilities["cuda"]:
        logger.warning("CUDA is not available, setting force_cpu to True.")
        config["force_cpu"] = True
def check_device(config: Optional[DictConfig] = None) -> str:
    """Return the compute device, optionally applying its settings to ``config``.

    Kept for backward compatibility with callers of the older API.

    Args:
        config: Optional configuration object. When given, the auto-detected
            device's settings are written into it via ``apply_device_config``.

    Returns:
        str: The selected device (``"cuda"``, ``"mps"`` or ``"cpu"``).
    """
    if config is None:
        return detect_device()
    return apply_device_config(config)