-
Notifications
You must be signed in to change notification settings - Fork 9
/
client.py
311 lines (251 loc) · 11.2 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""Low-level functions for making requests to MLHub API endpoints."""
import itertools as it
import urllib.parse
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import Iterator, List, Optional

from requests.exceptions import HTTPError

try:
    from tqdm.auto import tqdm
except ImportError:  # pragma: no cover
    # Handles this issue: https://github.com/tqdm/tqdm/issues/1082
    from tqdm import tqdm  # type: ignore [no-redef]

from .session import get_session
from .exceptions import EntityDoesNotExist, MLHubException
def _download(
    url: str,
    output_dir: Path,
    overwrite: bool = False,
    chunk_size: int = 5000000,
    **session_kwargs
):
    """Internal function used to parallelize downloads from a given URL.

    Parameters
    ----------
    url : str
        This can either be a full URL or a path relative to the Radiant MLHub root URL.
    output_dir : Path
        Path to a local directory to which the file will be downloaded. File name will be generated
        automatically based on the download URL.
    overwrite : bool, optional
        Whether to overwrite an existing file at ``output_path``. Defaults to ``False``.
    chunk_size : int, optional
        The size of byte range for each concurrent request.
    session_kwargs
        Keyword arguments passed directly to ``get_session``

    Raises
    ------
    FileExistsError
        If file of the same name already exists in ``output_dir`` and ``overwrite==False``.
    """
    def _get_ranges(total_size, interval):
        """Internal function for getting byte ranges from a total size and interval/chunk size."""
        start = 0
        while start < total_size:
            # HTTP byte ranges are inclusive, so the last valid offset is
            # total_size - 1 (the original requested one byte past EOF).
            end = min(start + interval - 1, total_size - 1)
            yield f'{start}-{end}'
            start += interval

    def _fetch_range(url_, range_):
        """Internal function for fetching a byte range from the url."""
        return session.get(url_, headers={'Range': f'bytes={range_}'}).content

    # Create a session
    session = get_session(**session_kwargs)

    # HEAD the endpoint and follow redirects to get the actual download URL and Content-Length
    r = session.head(url, allow_redirects=True)
    r.raise_for_status()
    content_length = int(r.headers['Content-Length'])
    download_url = r.url

    # Resolve user directory shortcuts and relative paths
    output_dir = Path(output_dir).expanduser().resolve()

    # Get the full file path (file name comes from the final, post-redirect URL)
    output_file_name = urllib.parse.urlsplit(download_url).path.rsplit('/', 1)[1]
    output_path = output_dir / output_file_name

    # Check for existing output file
    if output_path.exists() and not overwrite:
        raise FileExistsError(f'File {output_path} already exists. Use overwrite=True to overwrite this file.')

    # Create the parent directory, if it does not exist
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Check that the endpoint accepts byte range requests
    use_range = r.headers.get('Accept-Ranges') == 'bytes'
    if use_range:
        # If we can use range requests, make concurrent requests to the byte ranges we need...
        with ThreadPoolExecutor(max_workers=20) as executor:
            with output_path.open('wb') as dst:
                with tqdm(total=round(content_length / 1000000., 1), unit='M') as pbar:
                    # executor.map yields results in input order, so chunks are
                    # written sequentially even though they are fetched concurrently.
                    for chunk in executor.map(partial(_fetch_range, download_url), _get_ranges(content_length, chunk_size)):
                        dst.write(chunk)
                        # Use the actual chunk length so the bar does not
                        # overshoot on the final (possibly short) chunk.
                        pbar.update(round(len(chunk) / 1000000., 1))
    else:
        # ...if not, stream the response
        with session.get(url, stream=True, allow_redirects=True) as response:
            with output_path.open('wb') as dst:
                with tqdm(total=round(content_length / 1000000., 1), unit='M') as pbar:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        dst.write(chunk)
                        pbar.update(round(len(chunk) / 1000000., 1))
def list_datasets(**session_kwargs) -> List[dict]:
    """Gets a list of JSON-like dictionaries representing dataset objects returned by the Radiant MLHub ``GET /datasets`` endpoint.

    See the `MLHub API docs <https://docs.mlhub.earth/#radiant-mlhub-api>`_ for details.

    Parameters
    ----------
    **session_kwargs
        Keyword arguments passed directly to :func:`~radiant_mlhub.session.get_session`

    Returns
    -------
    datasets : List[dict]
    """
    response = get_session(**session_kwargs).get('datasets')
    return response.json()
def get_dataset(dataset_id: str, **session_kwargs) -> dict:
    """Returns a JSON-like dictionary representing the response from the Radiant MLHub ``GET /datasets/{dataset_id}`` endpoint.

    See the `MLHub API docs <https://docs.mlhub.earth/#radiant-mlhub-api>`_ for details.

    Parameters
    ----------
    dataset_id : str
        The ID of the dataset to fetch
    **session_kwargs
        Keyword arguments passed directly to :func:`~radiant_mlhub.session.get_session`

    Returns
    -------
    dataset : dict
    """
    session = get_session(**session_kwargs)
    try:
        response = session.get(f'datasets/{dataset_id}')
    except HTTPError as e:
        status = e.response.status_code
        # A 404 means the dataset ID is unknown; anything else is unexpected.
        if status == 404:
            raise EntityDoesNotExist(f'Dataset "{dataset_id}" does not exist.') from None
        raise MLHubException(f'An unknown error occurred: {status} ({e.response.reason})') from None
    return response.json()
def list_collections(**session_kwargs) -> List[dict]:
    """Gets a list of JSON-like dictionaries representing STAC Collection objects returned by the Radiant MLHub ``GET /collections``
    endpoint.

    See the `MLHub API docs <https://docs.mlhub.earth/#radiant-mlhub-api>`_ for details.

    Parameters
    ----------
    **session_kwargs
        Keyword arguments passed directly to :func:`~radiant_mlhub.session.get_session`

    Returns
    -------
    collections: List[dict]
        List of JSON-like dictionaries representing STAC Collection objects.
    """
    session = get_session(**session_kwargs)
    payload = session.get('collections').json()
    # The endpoint wraps the list in a "collections" key; default to empty.
    return payload.get('collections', [])
def get_collection(collection_id: str, **session_kwargs) -> dict:
    """Returns a JSON-like dictionary representing the response from the Radiant MLHub ``GET /collections/{p1}`` endpoint.

    See the `MLHub API docs <https://docs.mlhub.earth/#radiant-mlhub-api>`_ for details.

    Parameters
    ----------
    collection_id : str
        The ID of the collection to fetch
    **session_kwargs
        Keyword arguments passed directly to :func:`~radiant_mlhub.session.get_session`

    Returns
    -------
    collection : dict

    Raises
    ------
    EntityDoesNotExist
        If a 404 response code is returned by the API
    MLHubException
        If any other response code is returned
    """
    session = get_session(**session_kwargs)
    try:
        response = session.get(f'collections/{collection_id}')
    except HTTPError as e:
        status = e.response.status_code
        # A 404 means the collection ID is unknown; anything else is unexpected.
        if status == 404:
            raise EntityDoesNotExist(f'Collection "{collection_id}" does not exist.') from None
        raise MLHubException(f'An unknown error occurred: {status} ({e.response.reason})') from None
    return response.json()
def list_collection_items(
    collection_id: str,
    *,
    page_size: Optional[int] = None,
    extensions: Optional[List[str]] = None,
    limit: Optional[int] = 10,
    **session_kwargs
) -> Iterator[dict]:
    """Yields JSON-like dictionaries representing STAC Item objects returned by the Radiant MLHub
    ``GET /collections/{collection_id}/items`` endpoint.

    .. note::

        Because some collections may contain hundreds of thousands of items, this function limits the total number of responses
        to ``10`` by default. You can change this value by increasing the value of the ``limit`` keyword argument,
        or setting it to ``None`` to list all items. **Be aware that trying to list all items in a large collection may take a very
        long time.**

    Parameters
    ----------
    collection_id : str
        The ID of the collection from which to fetch items
    page_size : int, optional
        The number of items to return in each page. If set to ``None``, then this parameter will not be passed to the API and
        the default API value will be used (currently ``30``).
    extensions : list, optional
        If provided, then only items that support all of the extensions listed will be returned.
    limit : int, optional
        The maximum *total* number of items to yield. Defaults to ``10``. ``None`` yields all items.
    **session_kwargs
        Keyword arguments passed directly to :func:`~radiant_mlhub.session.get_session`

    Yields
    ------
    item : dict
        JSON-like dictionary representing a STAC Item associated with the given collection.
    """
    session = get_session(**session_kwargs)

    def _list_items():
        # Only include query parameters that were explicitly provided so the
        # API defaults apply otherwise.
        params = {}
        if page_size is not None:
            params['limit'] = page_size
        if extensions is not None:
            params['extensions'] = extensions
        for page in session.paginate(f'collections/{collection_id}/items', params=params):
            yield from page['features']

    # islice caps the total yielded items; a limit of None means "no cap".
    yield from it.islice(_list_items(), limit)
def get_collection_item(collection_id: str, item_id: str, **session_kwargs) -> dict:
    """Returns a JSON-like dictionary representing the response from the Radiant MLHub ``GET /collections/{p1}/items/{p2}`` endpoint.

    Parameters
    ----------
    collection_id : str
        The ID of the Collection to which the Item belongs.
    item_id : str
        The ID of the Item.
    **session_kwargs
        Keyword arguments passed directly to :func:`~radiant_mlhub.session.get_session`

    Returns
    -------
    item : dict

    Raises
    ------
    EntityDoesNotExist
        If a 404 response code is returned by the API (unknown collection or item).
    MLHubException
        If any other error response code is returned.
    """
    session = get_session(**session_kwargs)
    response = session.get(f'collections/{collection_id}/items/{item_id}')
    if response.ok:
        return response.json()
    if response.status_code == 404:
        # A 404 from this endpoint can mean either the collection or the item
        # is unknown; the original message only blamed the collection.
        raise EntityDoesNotExist(
            f'Collection "{collection_id}" or item "{item_id}" does not exist.'
        )
    raise MLHubException(f'An unknown error occurred: {response.status_code} ({response.reason})')
def download_archive(
    archive_id: str,
    output_dir: Optional[Path] = None,
    *,
    overwrite: bool = False,
    **session_kwargs
):
    """Downloads the archive with the given ID to an output location (current working directory by default).

    Parameters
    ----------
    archive_id : str
        The ID of the archive to download.
    output_dir : Path, optional
        Path to which the archive will be downloaded. Defaults to the current working directory.
    overwrite : bool, optional
        Whether to overwrite an existing file of the same name. Defaults to ``False``.
    **session_kwargs
        Keyword arguments passed directly to :func:`~radiant_mlhub.session.get_session`

    Raises
    ------
    FileExistsError
        If file at ``output_path`` already exists and ``overwrite==False``.
    EntityDoesNotExist
        If the API responds with a 404 (archive missing or still generating).
    """
    output_dir = output_dir if output_dir is not None else Path.cwd()
    try:
        _download(f'archive/{archive_id}', output_dir=output_dir, overwrite=overwrite, **session_kwargs)
    except HTTPError as e:
        # Re-raise anything that is not a 404 unchanged; a 404 gets a
        # friendlier, domain-specific error.
        if e.response.status_code != 404:
            raise
        raise EntityDoesNotExist(f'Archive "{archive_id}" does not exist and may still be generating. Please try again later.') from None