import hashlib
from pathlib import Path
from typing import List
import boto3
from botocore import UNSIGNED
from botocore.client import BaseClient, Config
from botocore.exceptions import ClientError
from botocore.exceptions import ConnectTimeoutError
from botocore.exceptions import ProfileNotFound
from flexsea.utilities.decorators import check_status_code
# ============================================
# s3_download
# ============================================
[docs]@check_status_code
def s3_download(
obj: str, bucket: str, dest: str, profile: str | None = None, timeout=60
) -> None:
"""
Downloads a file from S3.
Parameters
----------
obj : str
The name of the S3 object to download.
bucket : str
The name of the S3 bucket ``obj`` resides in.
dest : str
The path to where ``obj`` will be downloaded.
profile : str, optional
The name of the profile in the ``~/.aws/credentials`` file.
This profile should hold both the access key and secret access
key needed for downloading private or restricted files.
timeout : int, optional
Time, in seconds, spent trying to connect to S3 before an
exception is raised.
Raises
------
ValueError
If the given profile cannot be found.
ConnectTimeoutError
If a connection to S3 cannot be established within the
allotted time.
"""
# https://stackoverflow.com/a/34866092
if profile is None:
# pylint: disable=duplicate-code
client = boto3.client(
"s3",
config=Config(signature_version=UNSIGNED, connect_timeout=timeout),
region_name="us-east-1",
)
else:
try:
session = boto3.Session(profile_name=profile)
except ProfileNotFound as err:
msg = f"Error: invalid AWS profile `{profile}`"
raise ValueError(msg) from err
client = session.client("s3", config=Config(connect_timeout=timeout))
try:
client.download_file(bucket, obj, dest)
except ClientError:
# If the download fails, one possible reason is because we weren't given a
# valid object path, but, instead, just a base name, e.g., myfirmware.dfu
# instead of firmwareBucket/major.minor.patch/device/hw/myfirmware.dfu
# In this case we want to search the given bucket for the file
obj = s3_find_object(obj, bucket, client)
try:
client.download_file(bucket, obj, dest)
except ConnectTimeoutError as err:
raise RuntimeError("Could not connect to S3. Timeout.") from err
except ConnectTimeoutError as err:
raise RuntimeError("Could not connect to S3. Timeout.") from err
_validate_download(client, bucket, obj, dest)
print(f"Downloaded {obj} from {bucket} to {dest}")
# ============================================
# s3_find_object
# ============================================
[docs]@check_status_code
def s3_find_object(fileName: str, bucket: str, client: str) -> str:
"""
Searches the given bucket for the given file.
Returns the full object path if there's only one match. If there
aren't any matches or there's more than one, we fail.
Parameters
----------
fileName : str
The name of the S3 object to search for.
bucket : str
The name of the S3 bucket to search.
client : :py:class:`BaseClient`
The object responsible for connecting to S3.
Raises
------
FileNotFoundError
If the given ``fileName`` cannot be found.
Returns
-------
str
The full S3 path to the object.
Notes
-----
Paginator use: https://tinyurl.com/4scnuk6c
"""
paginator = client.get_paginator("list_objects_v2")
pageIterator = paginator.paginate(Bucket=bucket)
objects = pageIterator.search(f"Contents[?contains(Key, `{fileName}`)][]")
# There should only be one match
# Objects is a generator, so we have to convert it to a list to check length
items = []
for item in objects:
items.append(item["Key"])
if len(items) == 0:
raise FileNotFoundError(f"Could not find: {fileName} in {bucket}")
if len(items) > 1:
raise FileNotFoundError(f"Found multiple options for: {fileName} in {bucket}")
return items[0]
# ============================================
# _validate_download
# ============================================
def _validate_download(
client: BaseClient, bucket: str, fileObj: str, dest: str
) -> None:
"""
Compares the AWS md5 hash to the local md5 hash to make sure the
files are the same.
S3 objects have an attribute called 'ETag', which is a string.
There are two possibilities: the string is a hex digest or the
string is a hex digest + -NUMBER.
In the case that ETag has no -NUMBER suffix, it means that the ETag
is the md5 hash of the file contents.
In the case that there is a -NUMBER suffix, it means that the file
was uploaded to S3 in NUMBER chunks and that the hex digest is actually
the hex digest of the digests of all of the chunks concatenated together.
Parameters
----------
client : :py:class:`BaseClient`
The object that allows use to communicate with S3.
bucket : str
The name of the bucket the file came from.
fileObj : str
The name of the fileObj we are validating.
dest : str
The name of the dowloaded file on disk.
Raises
-------
FileNotFoundError
If the downloaded file is not present in the desired destinationself.
AssertionError
If the hash of the remote object does not match the hash of the local
file.
"""
try:
assert Path(dest).exists()
except AssertionError as err:
raise FileNotFoundError from err
# Check the local file's integrity by comparing its md5 hash to
# AWS's md5 hash, called ETag
objData = client.head_object(Bucket=bucket, Key=fileObj)
etag = objData["ETag"].strip('"')
try:
nChunks = int(etag.split("-")[1])
except IndexError:
nChunks = 1
if nChunks == 1:
with open(dest, "rb") as fd:
data = fd.read()
localHash = hashlib.md5(data).hexdigest()
elif nChunks > 1:
chunkHashes = []
with open(dest, "rb") as fd:
for chunk in range(1, nChunks + 1):
objData = client.head_object(
Bucket=bucket, Key=fileObj, PartNumber=chunk
)
chunkSize = objData["ContentLength"]
data = fd.read(chunkSize)
if data:
chunkHashes.append(hashlib.md5(data))
else:
break
if len(chunkHashes) == 1:
localHash = chunkHashes[0].hexdigest()
else:
digests = b"".join([m.digest() for m in chunkHashes])
digestsMd5 = hashlib.md5(digests)
localHash = f"{digestsMd5.hexdigest()}-{len(chunkHashes)}"
assert localHash == etag
# ============================================
# get_s3_objects
# ============================================
[docs]@check_status_code
def get_s3_objects(bucket: str, client: BaseClient, prefix: str = "") -> List:
"""
Recursively loops over all directories in a bucket and returns a
list of files.
Parameters
----------
bucket : str
The name of the bucket we're getting files from.
client : :py:class:`BaseClient`
The object providing an interface to S3.
prefix : str
The directory we're looping over. If `""`, then we get the
top-level directories.
Returns
-------
List[str]
A list of all the objects in the bucket.
"""
objectList = []
objects = client.list_objects_v2(Bucket=bucket, Delimiter="/", Prefix=prefix)
if "CommonPrefixes" in objects:
for pre in objects["CommonPrefixes"]:
objectList += get_s3_objects(bucket, client, pre["Prefix"])
if "Contents" in objects:
return objectList + [obj["Key"] for obj in objects["Contents"][1:]]
return objectList