Coverage for instanovo/utils/s3.py: 76%

119 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-08 07:26 +0000

1from __future__ import annotations 

2 

3import os 

4import tempfile 

5from pathlib import Path 

6from typing import Any, Callable 

7from urllib.parse import urlparse 

8 

9import s3fs 

10from tensorboard.compat.tensorflow_stub.io.gfile import _REGISTERED_FILESYSTEMS, register_filesystem 

11 

12from instanovo.__init__ import console 

13from instanovo.utils.colorlogging import ColorLog 

14 

# Module-wide logger, created through the project's ColorLog wrapper around the package `console`.
logger = ColorLog(console, __name__).logger

16 

17 

class S3FileHandler:
    """A utility class for handling files stored locally or on S3.

    Paths that start with ``s3://`` are routed through an ``s3fs`` filesystem
    when one could be created; every other path is treated as a local path.

    Attributes:
        s3 (s3fs.core.S3FileSystem | None): The S3 filesystem, or ``None`` when
            no endpoint was passed and ``S3_ENDPOINT`` is not set in the
            environment. When ``None``, the S3 helpers below fall back to
            returning the input path / doing nothing.
        temp_dir (tempfile.TemporaryDirectory): A temporary directory
            for storing downloaded S3 files and staging uploads.
        verbose (bool): Whether informational log messages are emitted.
    """

    def __init__(
        self,
        s3_endpoint: str | None = None,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        verbose: bool = True,
    ) -> None:
        """Initializes the S3FileHandler.

        Args:
            s3_endpoint: Optional S3 endpoint to use. If not provided,
                the S3_ENDPOINT environment variable will be used.
            aws_access_key_id: Optional AWS access key ID to use. If not provided,
                the AWS_ACCESS_KEY_ID environment variable will be used.
            aws_secret_access_key: Optional AWS secret access key to use. If not provided,
                the AWS_SECRET_ACCESS_KEY environment variable will be used.
            verbose: Whether to log verbose messages.

        Raises:
            AssertionError: If S3 is enabled but a required environment
                variable is missing (see ``_create_s3fs``).
        """
        self.s3 = S3FileHandler._create_s3fs(
            verbose,
            s3_endpoint,
            aws_access_key_id,
            aws_secret_access_key,
        )
        self.temp_dir = tempfile.TemporaryDirectory()
        self.verbose = verbose

    @staticmethod
    def s3_enabled() -> bool:
        """Check whether the S3_ENDPOINT environment variable is present."""
        return "S3_ENDPOINT" in os.environ

    @staticmethod
    def _create_s3fs(
        verbose: bool = True,
        s3_endpoint: str | None = None,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
    ) -> s3fs.core.S3FileSystem | None:
        """Create an ``s3fs`` filesystem, or return ``None`` if S3 is not configured.

        Args:
            verbose: Whether to log creation messages.
            s3_endpoint: Explicit endpoint URL; falls back to ``S3_ENDPOINT``.
            aws_access_key_id: Explicit access key; if ``None`` the
                ``AWS_ACCESS_KEY_ID`` environment variable must be set.
            aws_secret_access_key: Explicit secret; if ``None`` the
                ``AWS_SECRET_ACCESS_KEY`` environment variable must be set.

        Returns:
            The filesystem, or ``None`` when ``s3_endpoint`` is ``None`` and
            ``S3_ENDPOINT`` is not set.

        Raises:
            AssertionError: If a required setting is neither passed explicitly
                nor present in the environment.
        """
        if not S3FileHandler.s3_enabled() and s3_endpoint is None:
            return None

        # NOTE: these are `assert` checks, so they are skipped under `python -O`.
        if s3_endpoint is None:
            assert "S3_ENDPOINT" in os.environ, "S3_ENDPOINT must be set when no s3_endpoint is given"
        if aws_access_key_id is None:
            assert "AWS_ACCESS_KEY_ID" in os.environ, "AWS_ACCESS_KEY_ID must be set when no aws_access_key_id is given"
        if aws_secret_access_key is None:
            assert (
                "AWS_SECRET_ACCESS_KEY" in os.environ
            ), "AWS_SECRET_ACCESS_KEY must be set when no aws_secret_access_key is given"

        url = s3_endpoint or os.environ.get("S3_ENDPOINT")
        if verbose:
            logger.info(f"Creating s3fs.core.S3FileSystem, Endpoint: {url}")
        # key/secret may be None here; s3fs presumably falls back to its default
        # credential lookup (which covers the env vars checked above).
        s3 = s3fs.core.S3FileSystem(key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs={"endpoint_url": url})
        if verbose:
            logger.info(f"Created s3fs.core.S3FileSystem: {s3}")

        return s3

    def _log_if_verbose(self, message: str) -> None:
        """Log a message if verbose logging is enabled.

        Args:
            message (str): The message to log.
        """
        if self.verbose:
            logger.info(message)

    def _download_from_s3(self, s3_path: str) -> str:
        """Downloads a file from S3 to a temporary directory.

        Args:
            s3_path (str): The S3 path of the file.

        Returns:
            str: The local file path where the file is saved, or ``s3_path``
            unchanged when S3 is not configured.
        """
        if self.s3 is None:
            return s3_path
        parsed = urlparse(s3_path)
        bucket, key = parsed.netloc, parsed.path.lstrip("/")
        # NOTE: files are stored by basename only, so two objects with the same
        # file name (different prefixes) overwrite each other in temp_dir.
        local_path = os.path.join(self.temp_dir.name, os.path.basename(s3_path))

        self._log_if_verbose(f"Downloading {bucket}/{key} to {local_path}")
        self.s3.get(f"{bucket}/{key}", local_path)
        return local_path

    def download(self, s3_path: str, local_path: str) -> None:
        """Downloads a file from S3 to a local path.

        Does nothing when ``s3_path`` is not an ``s3://`` path or S3 is not
        configured.

        Args:
            s3_path (str): The source S3 path (e.g., s3://bucket/key).
            local_path (str): The path to the local file to be written. Its
                parent directory is created if needed.
        """
        if not s3_path.startswith("s3://") or self.s3 is None:
            return
        parsed = urlparse(s3_path)
        bucket, key = parsed.netloc, parsed.path.lstrip("/")
        dir_path = os.path.dirname(local_path)
        if dir_path:  # Only create directories if there's a directory component
            os.makedirs(dir_path, exist_ok=True)
        self._log_if_verbose(f"Downloading {bucket}/{key} to {local_path}")
        self.s3.get(f"{bucket}/{key}", local_path)

    def get_local_path(self, path: str, missing_ok: bool = False) -> str | None:
        """Returns a local file path. If the input path is an S3 path, the file is downloaded first.

        Args:
            path (str): The local or S3 path.
            missing_ok (bool): If True, return ``None`` instead of raising when
                an S3 path does not exist.

        Returns:
            str | None: The local path of the downloaded file for S3 inputs;
            ``path`` unchanged when it is not an S3 path or S3 is not
            configured; ``None`` when the S3 object is missing and
            ``missing_ok`` is True.

        Raises:
            FileNotFoundError: If the S3 object does not exist and
                ``missing_ok`` is False.
        """
        if not (path.startswith("s3://") and self.s3 is not None):
            return path  # Already a local path (or S3 unavailable)

        if not self.s3.exists(path):
            if missing_ok:
                return None
            raise FileNotFoundError(f"Could not find {path}.")

        # NOTE: keyed by basename only; same-named objects from different
        # prefixes overwrite each other in temp_dir.
        local_path = os.path.join(self.temp_dir.name, os.path.basename(path))
        self.download(path, local_path)
        return local_path

    def upload(self, local_path: str, s3_path: str) -> None:
        """Uploads a local file to S3.

        Does nothing when ``s3_path`` is not an ``s3://`` path or S3 is not
        configured.

        Args:
            local_path (str): The path to the local file.
            s3_path (str): The destination S3 path (e.g., s3://bucket/key).
        """
        if not s3_path.startswith("s3://") or self.s3 is None:
            return
        parsed = urlparse(s3_path)
        bucket, key = parsed.netloc, parsed.path.lstrip("/")

        self._log_if_verbose(f"Uploading {local_path} to {bucket}/{key}")
        self.s3.put(local_path, f"{bucket}/{key}")

    def upload_to_s3_wrapper(self, save_func: Callable[..., Any], s3_path: str, *args: Any, **kwargs: Any) -> Any:
        """Calls a save function and uploads the resulting file to S3.

        When ``s3_path`` is not an ``s3://`` path or S3 is not configured, the
        path is treated as local: its parent directory is created and
        ``save_func`` writes there directly.

        Args:
            save_func (Callable[..., Any]): The function to save a file (e.g., torch.save).
                It is called as ``save_func(path, *args, **kwargs)``.
            s3_path (str): The destination S3 path.
            *args: Additional positional arguments for the save function.
            **kwargs: Additional keyword arguments for the save function.

        Returns:
            Any: Whatever ``save_func`` returns.
        """
        if not s3_path.startswith("s3://") or self.s3 is None:
            parent = os.path.dirname(s3_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            return save_func(s3_path, *args, **kwargs)

        # Stage the file in the temp dir, then push it to S3.
        local_path = os.path.join(self.temp_dir.name, os.path.basename(urlparse(s3_path).path))
        result = save_func(local_path, *args, **kwargs)
        self.upload(local_path, s3_path)
        return result

    def listdir(self, path: str) -> list[str]:
        """List the contents of a directory on S3.

        Args:
            path (str): The path to the directory.

        Returns:
            list[str]: The entry names reported by s3fs, or an empty list when
            ``path`` is not an ``s3://`` path or S3 is not configured.
        """
        if not path.startswith("s3://") or self.s3 is None:
            return []

        parsed = urlparse(path)
        bucket, key = parsed.netloc, parsed.path.lstrip("/")
        return self.s3.listdir(f"{bucket}/{key}", detail=False)  # type: ignore[no-any-return]

    @staticmethod
    def _aichor_enabled() -> bool:
        """Check if Aichor is enabled (``AICHOR_LOGS_PATH`` is set).

        Raises:
            AssertionError: If Aichor is enabled but ``S3_ENDPOINT`` is not set.
        """
        if "AICHOR_LOGS_PATH" in os.environ:
            assert "S3_ENDPOINT" in os.environ, "S3_ENDPOINT must be set when AICHOR_LOGS_PATH is set"
            return True
        return False

    @staticmethod
    def register_tb() -> bool:
        """Register s3 filesystem to tensorboard.

        Returns:
            bool: Whether the registration was successful. ``False`` when Aichor
            is not enabled or tensorboard has no ``"s3"`` filesystem registered.
        """
        if not S3FileHandler._aichor_enabled():
            return False
        fs = _REGISTERED_FILESYSTEMS.get("s3")
        if fs is None:
            # NOTE(review): tensorboard presumably skips registering "s3" when its
            # optional S3 dependency is not importable — confirm for your install.
            logger.warning("TensorBoard has no 's3' filesystem registered; skipping S3 registration.")
            return False
        if fs._s3_endpoint is None:
            # Set custom S3 endpoint explicitly. Not sure why this isn't picked up here:
            # https://github.com/tensorflow/tensorboard/blob/153cc747fdbeca3545c81947d4880d139a185c52/tensorboard/compat/tensorflow_stub/io/gfile.py#L227
            # (possibly because the filesystem object is created before S3_ENDPOINT is set — unverified)
            fs._s3_endpoint = os.environ["S3_ENDPOINT"]
        register_filesystem("s3", fs)
        return True

    @staticmethod
    def convert_to_s3_output(path: str) -> str:
        """Convert local directory to s3 output path if possible.

        Args:
            path (str): The local path to convert.

        Returns:
            str: The s3 output path when Aichor is enabled, otherwise ``path``
            unchanged. An absolute ``path`` is made relative before joining.

        Raises:
            KeyError: If Aichor is enabled but ``AICHOR_OUTPUT_PATH`` is not set.
        """
        if not S3FileHandler._aichor_enabled():
            return path

        # Environment variable specific to Aichor compute platform
        output_path = os.environ["AICHOR_OUTPUT_PATH"]
        if "s3://" not in output_path:
            output_path = f"s3://{output_path}/output"

        if path.startswith("/"):
            path = str(Path(path).relative_to("/"))

        return os.path.join(output_path, path)

    def cleanup(self) -> None:
        """Clean up the temporary directory.

        Safe on a partially-initialized instance (e.g. when ``__init__`` raised
        before ``temp_dir`` was created), where there is nothing to remove.
        """
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir is not None:
            temp_dir.cleanup()

    def __del__(self) -> None:
        """Remove the temporary directory when the handler is garbage-collected."""
        self.cleanup()