Coverage for instanovo/utils/s3.py: 76%
119 statements
coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
from __future__ import annotations

import os
import tempfile
from pathlib import Path
from typing import Any, Callable
from urllib.parse import urlparse

import s3fs
from tensorboard.compat.tensorflow_stub.io.gfile import _REGISTERED_FILESYSTEMS, register_filesystem

from instanovo.__init__ import console
from instanovo.utils.colorlogging import ColorLog

logger = ColorLog(console, __name__).logger


class S3FileHandler:
    """A utility class for handling files stored locally or on S3.

    Attributes:
        temp_dir (tempfile.TemporaryDirectory): A temporary directory
            for storing downloaded S3 files.
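
    Example (illustrative sketch; the endpoint and paths are hypothetical):
        >>> handler = S3FileHandler(s3_endpoint="https://s3.example.com")
        >>> handler.upload("predictions.csv", "s3://my-bucket/predictions.csv")
        >>> handler.cleanup()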
24 """
26 def __init__(
27 self, s3_endpoint: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, verbose: bool = True
28 ) -> None:
29 """Initializes the S3FileHandler.
31 Args:
32 s3_endpoint: Optional S3 endpoint to use. If not provided,
33 the S3_ENDPOINT environment variable will be used.
34 aws_access_key_id: Optional AWS access key ID to use. If not provided,
35 the AWS_ACCESS_KEY_ID environment variable will be used.
36 aws_secret_access_key: Optional AWS secret access key to use. If not provided,
37 the AWS_SECRET_ACCESS_KEY environment variable will be used.
38 verbose: Whether to log verbose messages.
39 """
40 self.s3 = S3FileHandler._create_s3fs(
41 verbose,
42 s3_endpoint,
43 aws_access_key_id,
44 aws_secret_access_key,
45 )
46 self.temp_dir = tempfile.TemporaryDirectory()
47 self.verbose = verbose

    @staticmethod
    def s3_enabled() -> bool:
        """Check if the S3_ENDPOINT environment variable is present."""
        return "S3_ENDPOINT" in os.environ

    @staticmethod
    def _create_s3fs(
        verbose: bool = True,
        s3_endpoint: str | None = None,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
    ) -> s3fs.core.S3FileSystem | None:
        if not S3FileHandler.s3_enabled() and s3_endpoint is None:
            return None

        if s3_endpoint is None:
            assert "S3_ENDPOINT" in os.environ
        if aws_access_key_id is None:
            assert "AWS_ACCESS_KEY_ID" in os.environ
        if aws_secret_access_key is None:
            assert "AWS_SECRET_ACCESS_KEY" in os.environ

        url = s3_endpoint or os.environ.get("S3_ENDPOINT")
        if verbose:
            logger.info(f"Creating s3fs.core.S3FileSystem, Endpoint: {url}")
        s3 = s3fs.core.S3FileSystem(
            key=aws_access_key_id,
            secret=aws_secret_access_key,
            client_kwargs={"endpoint_url": url},
        )
        if verbose:
            logger.info(f"Created s3fs.core.S3FileSystem: {s3}")

        return s3

    def _log_if_verbose(self, message: str) -> None:
        """Log a message if verbose logging is enabled.

        Args:
            message (str): The message to log.
        """
        if self.verbose:
            logger.info(message)

    def _download_from_s3(self, s3_path: str) -> str:
        """Downloads a file from S3 to a temporary directory.

        Args:
            s3_path (str): The S3 path of the file.

        Returns:
            str: The local file path where the file is saved.
        """
        if self.s3 is None:
            return s3_path
        parsed = urlparse(s3_path)
        bucket, key = parsed.netloc, parsed.path.lstrip("/")
        local_path = os.path.join(self.temp_dir.name, os.path.basename(s3_path))

        self._log_if_verbose(f"Downloading {bucket}/{key} to {local_path}")
        self.s3.get(f"{bucket}/{key}", local_path)
        return local_path

    def download(self, s3_path: str, local_path: str) -> None:
        """Downloads a file from S3 to a local path.

        Args:
            s3_path (str): The source S3 path (e.g., s3://bucket/key).
            local_path (str): The path to the local file to be written.
        """
        if not s3_path.startswith("s3://") or self.s3 is None:
            return
        parsed = urlparse(s3_path)
        bucket, key = parsed.netloc, parsed.path.lstrip("/")
        dir_path = os.path.dirname(local_path)
        if dir_path:  # Only create directories if there's a directory component
            os.makedirs(dir_path, exist_ok=True)
        self._log_if_verbose(f"Downloading {bucket}/{key} to {local_path}")
        self.s3.get(f"{bucket}/{key}", local_path)

    def get_local_path(self, path: str, missing_ok: bool = False) -> str | None:
        """Returns a local file path. If the input path is an S3 path, the file is downloaded first.

        Args:
            path (str): The local or S3 path.
            missing_ok (bool): If True, return None when the S3 path does not
                exist instead of raising FileNotFoundError.

        Returns:
            str | None: The local file path, or None if the file is missing
                and missing_ok is True.
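
        Example (illustrative sketch; the bucket and key are hypothetical):
            >>> handler = S3FileHandler()
            >>> local = handler.get_local_path("s3://my-bucket/data/spectra.parquet")
            >>> # `local` is a copy under handler.temp_dir; a plain local input
            >>> # path would be returned unchanged.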
133 """
134 if path.startswith("s3://") and self.s3 is not None:
135 if not self.s3.exists(path):
136 if missing_ok:
137 return None
138 else:
139 raise FileNotFoundError(f"Could not find {path}.")
141 local_path = os.path.join(self.temp_dir.name, os.path.basename(path))
142 self.download(path, local_path)
143 return local_path
144 return path # Already a local path

    def upload(self, local_path: str, s3_path: str) -> None:
        """Uploads a local file to S3.

        Args:
            local_path (str): The path to the local file.
            s3_path (str): The destination S3 path (e.g., s3://bucket/key).
        """
        if not s3_path.startswith("s3://") or self.s3 is None:
            return
        parsed = urlparse(s3_path)
        bucket, key = parsed.netloc, parsed.path.lstrip("/")

        self._log_if_verbose(f"Uploading {local_path} to {bucket}/{key}")
        self.s3.put(local_path, f"{bucket}/{key}")

    def upload_to_s3_wrapper(
        self, save_func: Callable[..., Any], s3_path: str, *args: Any, **kwargs: Any
    ) -> Any:
        """Calls a save function and uploads the resulting file to S3.

        If s3_path is a local path, or S3 is not configured, the save function
        writes to it directly and nothing is uploaded.

        Args:
            save_func (Callable[..., Any]): The function that writes the file;
                it is called with the output path as its first argument.
            s3_path (str): The destination S3 path.
            *args: Additional positional arguments for the save function.
            **kwargs: Additional keyword arguments for the save function.

        Returns:
            Any: The return value of save_func.
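
        Example (illustrative sketch; `df` is assumed to be a polars DataFrame
        and the bucket is hypothetical; any callable taking the output path as
        its first argument works):
            >>> handler.upload_to_s3_wrapper(df.write_csv, "s3://my-bucket/predictions.csv")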
169 """
170 if not s3_path.startswith("s3://") or self.s3 is None:
171 if os.path.dirname(s3_path):
172 os.makedirs(os.path.dirname(s3_path), exist_ok=True)
173 return save_func(s3_path, *args, **kwargs)
175 local_path = os.path.join(self.temp_dir.name, os.path.basename(urlparse(s3_path).path))
176 result = save_func(local_path, *args, **kwargs)
177 self.upload(local_path, s3_path)
178 return result

    def listdir(self, path: str) -> list[str]:
        """List the contents of a directory on S3.

        Args:
            path (str): The path to the directory.

        Returns:
            list[str]: The directory contents, or an empty list if the path is
                not an S3 path or S3 is not configured.
        """
        if not path.startswith("s3://") or self.s3 is None:
            return []

        parsed = urlparse(path)
        bucket, key = parsed.netloc, parsed.path.lstrip("/")
        return self.s3.listdir(f"{bucket}/{key}", detail=False)  # type: ignore[no-any-return]

    @staticmethod
    def _aichor_enabled() -> bool:
        """Check if Aichor is enabled."""
        if "AICHOR_LOGS_PATH" in os.environ:
            assert "S3_ENDPOINT" in os.environ
            return True
        return False

    @staticmethod
    def register_tb() -> bool:
        """Register the s3 filesystem with TensorBoard.

        Returns:
            bool: Whether the registration was successful.
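
        Example (illustrative sketch; SummaryWriter from torch.utils.tensorboard
        and the Aichor environment are assumptions):
            >>> if S3FileHandler.register_tb():
            ...     writer = SummaryWriter(os.environ["AICHOR_LOGS_PATH"])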
207 """
208 if not S3FileHandler._aichor_enabled():
209 return False
210 fs = _REGISTERED_FILESYSTEMS["s3"]
211 if fs._s3_endpoint is None:
212 # Set custom S3 endpoint explicitly. Not sure why this isn't picked up here:
213 # https://github.com/tensorflow/tensorboard/blob/153cc747fdbeca3545c81947d4880d139a185c52/tensorboard/compat/tensorflow_stub/io/gfile.py#L227
214 fs._s3_endpoint = os.environ["S3_ENDPOINT"]
215 register_filesystem("s3", fs)
216 return True

    @staticmethod
    def convert_to_s3_output(path: str) -> str:
        """Convert a local directory to an s3 output path if possible.

        Args:
            path (str): The local path to convert.

        Returns:
            str: The s3 output path, or the original path when not running on Aichor.
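
        Example (illustrative sketch; the AICHOR_OUTPUT_PATH value is hypothetical):
            >>> # With AICHOR_OUTPUT_PATH="my-bucket/job-123"
            >>> S3FileHandler.convert_to_s3_output("/checkpoints/run")
            's3://my-bucket/job-123/output/checkpoints/run'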
227 """
228 if S3FileHandler._aichor_enabled():
229 # Environment variable specific to Aichor compute platform
230 output_path = os.environ["AICHOR_OUTPUT_PATH"]
231 if "s3://" not in output_path:
232 output_path = f"s3://{output_path}/output"
234 if path.startswith("/"):
235 path = str(Path(path).relative_to("/"))
237 return os.path.join(output_path, path)
238 return path

    def cleanup(self) -> None:
        """Clean up the temporary directory."""
        self.temp_dir.cleanup()

    def __del__(self) -> None:
        """Clean up the temporary directory when the handler is garbage collected."""
        self.cleanup()