# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
"""
Helper functions for image upload operations.
This module contains the implementation of image upload workflows
for the ImageEndpointAPI. These functions are separated from the main endpoint
implementation to improve code organization and testability.
"""
from __future__ import annotations
import os
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import timedelta
from http import HTTPStatus
from pathlib import Path
from typing import TYPE_CHECKING, Any
import requests
from air_sdk.const import (
DEFAULT_RETRY_ATTEMPTS,
DEFAULT_RETRY_BACKOFF_FACTOR,
DEFAULT_UPLOAD_TIMEOUT,
MULTIPART_CHUNK_SIZE,
MULTIPART_MIN_PART_SIZE,
)
from air_sdk.endpoints import mixins
from air_sdk.exceptions import AirUnexpectedResponse
from air_sdk.utils import (
FilePartReader,
join_urls,
raise_if_invalid_response,
sha256_file,
)
if TYPE_CHECKING:
from air_sdk import AirAPI
from air_sdk.endpoints.images import Image
def abort_multipart_upload(
*,
api_client: AirAPI,
base_url: str,
image: Image,
) -> None:
"""Abort a multipart upload to clean up S3 resources.
This should be called when a multipart upload fails, to prevent
orphaned uploads from lingering in S3 storage.
Args:
api_client: The AirAPI client instance for making HTTP requests
base_url: Base URL for the images endpoint
image: Image instance (must be in UPLOADING status)
Note:
This method does not raise exceptions on failure, as it's meant
for cleanup during error handling. Warnings are issued instead.
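Example:
    A minimal sketch of the cleanup pattern this helper supports;
    api, images_url, and image are hypothetical stand-ins for a real
    client, endpoint URL, and Image record:

        try:
            ...  # upload parts, then complete the upload
        except Exception:
            # Best-effort cleanup; failures surface as warnings only
            abort_multipart_upload(
                api_client=api, base_url=images_url, image=image
            )
            raise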
"""
try:
abort_url = join_urls(base_url, str(image.id), 'abort-upload')
abort_response = api_client.client.patch(abort_url, data='{}')
# Expect 204 No Content on success
if abort_response.status_code != HTTPStatus.NO_CONTENT:
warnings.warn(
f'Failed to abort multipart upload '
f'for image {image.id}: HTTP {abort_response.status_code}.',
stacklevel=3,
)
except Exception as e:
# Catch all exceptions to prevent masking the original error
warnings.warn(
f'Exception while aborting multipart upload for image {image.id}: {e}.',
stacklevel=3,
)
def upload_single_part(
*,
api_client: AirAPI,
filepath: str | Path,
part_number: int,
start_offset: int,
part_size: int,
presigned_url: str,
timeout: float,
max_retries: int = DEFAULT_RETRY_ATTEMPTS,
) -> dict[str, Any]:
"""Upload a single part to S3 with retry logic for transient failures.
Automatically retries on transient network errors (connection errors,
timeouts, 429/502/503/504 responses) with exponential backoff.
Non-transient errors are raised immediately.
Args:
api_client: The AirAPI client instance (used for verify setting)
filepath: Path to the file to upload
part_number: Part number (1-indexed)
start_offset: Starting byte offset in file
part_size: Size of this part in bytes
presigned_url: S3 presigned URL for this part
timeout: Timeout in seconds for the upload
max_retries: Maximum number of retry attempts
(default: DEFAULT_RETRY_ATTEMPTS)
Returns:
Dict with 'part_number' and 'etag' keys
Raises:
AirUnexpectedResponse: If upload fails or S3 doesn't return ETag
requests.RequestException: If all retry attempts fail
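Example:
    A sketch of a direct call; the path and part size are illustrative,
    and in practice this function is driven by upload_parts_to_s3. After
    a failed attempt i (0-indexed), the retry waits
    DEFAULT_RETRY_BACKOFF_FACTOR * 2**i seconds:

        part = upload_single_part(
            api_client=api,  # an authenticated AirAPI instance
            filepath='/tmp/image.qcow2',
            part_number=1,
            start_offset=0,
            part_size=104_857_600,  # 100 MiB
            presigned_url=presigned_url,  # from the start-upload response
            timeout=300.0,
        )
        # part == {'part_number': 1, 'etag': '<etag returned by S3>'}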
"""
last_exception: Exception | None = None
upload_response: requests.Response | None = None
for attempt in range(max_retries):
retry_reason = None # Will be set if we should retry
last_exception = None # Reset each attempt to track the current failure
try:
with FilePartReader(filepath, start_offset, part_size) as part_reader:
upload_response = requests.put(
presigned_url,
data=part_reader,
timeout=timeout,
verify=api_client.client.verify,
)
# Check for transient HTTP errors that should be retried
# 429: Too Many Requests (rate limiting)
# 502: Bad Gateway (upstream server error)
# 503: Service Unavailable (temporary overload/maintenance)
# 504: Gateway Timeout (upstream timeout)
if upload_response.status_code in (429, 502, 503, 504):
retry_reason = f'HTTP {upload_response.status_code}'
else:
# Not a transient error - validate the response
raise_if_invalid_response(
upload_response, status_code=HTTPStatus.OK, data_type=None
)
etag = upload_response.headers.get('ETag', '').strip('"')
if not etag:
raise AirUnexpectedResponse(
f'S3 did not return ETag for part {part_number}. '
f'Upload may have failed silently.'
)
return {'part_number': part_number, 'etag': etag}
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.ChunkedEncodingError,
) as e:
last_exception = e
retry_reason = f'{type(e).__name__}: {e}'
except Exception:
# Don't retry on non-transient errors
# (e.g., file not found, invalid response)
raise
# Common retry logic
if retry_reason:
if attempt < max_retries - 1:
wait_time = DEFAULT_RETRY_BACKOFF_FACTOR * (2**attempt)
warnings.warn(
f'Part {part_number} upload failed ({retry_reason}). '
f'Retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})',
stacklevel=4,
)
time.sleep(wait_time)
continue
else:
# Last attempt - raise the appropriate error
if last_exception:
raise last_exception
# For HTTP errors, validate response to raise proper error
# upload_response is guaranteed to exist here since we only
# reach this path after receiving an HTTP response
assert upload_response is not None
raise_if_invalid_response(
upload_response, status_code=HTTPStatus.OK, data_type=None
)
# Should never reach here, but just in case
if last_exception:
raise last_exception
raise AirUnexpectedResponse(
f'Part {part_number} upload failed after {max_retries} attempts'
)
def upload_parts_to_s3(
*,
api_client: AirAPI,
filepath: str | Path,
parts_info: list[dict[str, int]],
part_urls: list[dict[str, Any]],
timeout_per_part: float,
max_workers: int = 1,
) -> list[dict[str, Any]]:
"""Upload file parts directly to S3 using presigned URLs.
Supports both sequential (max_workers=1) and parallel (max_workers>1) uploads.
Args:
api_client: The AirAPI client instance
filepath: Path to the file to upload
parts_info: List of part information (part_number, start, size)
part_urls: List of presigned URL data from backend
timeout_per_part: Timeout in seconds for each part upload
max_workers: Number of concurrent upload workers. Default: 1 (sequential)
Returns:
List of uploaded parts with part_number and etag, sorted by part_number
Raises:
AirUnexpectedResponse: If any part upload fails
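Example:
    A sketch of the expected input shapes; the values are illustrative
    and would normally come from calculate_parts_info and the
    start-upload response:

        parts_info = [
            {'part_number': 1, 'start': 0, 'size': 104_857_600},
            {'part_number': 2, 'start': 104_857_600, 'size': 52_428_800},
        ]
        part_urls = [
            {'url': 'https://s3.example.com/part-1?signature=...'},
            {'url': 'https://s3.example.com/part-2?signature=...'},
        ]
        uploaded = upload_parts_to_s3(
            api_client=api,
            filepath='/tmp/image.qcow2',
            parts_info=parts_info,
            part_urls=part_urls,
            timeout_per_part=300.0,
            max_workers=4,  # upload up to four parts concurrently
        )
        # uploaded == [{'part_number': 1, 'etag': '...'}, ...]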
"""
if max_workers == 1:
# Sequential upload (default)
uploaded_parts = []
for part_info, part_url_data in zip(parts_info, part_urls):
result = upload_single_part(
api_client=api_client,
filepath=filepath,
part_number=part_info['part_number'],
start_offset=part_info['start'],
part_size=part_info['size'],
presigned_url=part_url_data['url'],
timeout=timeout_per_part,
)
uploaded_parts.append(result)
# Sort by part number to ensure correct order
uploaded_parts.sort(key=lambda x: x['part_number'])
return uploaded_parts
# Parallel upload
uploaded_parts = []
failed_parts = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all upload tasks
future_to_part = {}
for part_info, part_url_data in zip(parts_info, part_urls):
future = executor.submit(
upload_single_part,
api_client=api_client,
filepath=filepath,
part_number=part_info['part_number'],
start_offset=part_info['start'],
part_size=part_info['size'],
presigned_url=part_url_data['url'],
timeout=timeout_per_part,
)
future_to_part[future] = part_info['part_number']
# Collect results as they complete
try:
for future in as_completed(future_to_part):
part_number = future_to_part[future]
try:
result = future.result()
uploaded_parts.append(result)
except Exception as e:
# Part upload failed - cancel remaining uploads to save bandwidth
failed_parts.append((part_number, str(e)))
for f in future_to_part:
if not f.done():
f.cancel()
break # Stop collecting results
except Exception:
# On unexpected error during result collection, cancel any pending futures
# to avoid leaving orphaned threads running
for future in future_to_part:
future.cancel()
raise
# Check if all parts uploaded successfully
if failed_parts:
error_msg = 'Failed to upload the following parts:\n'
for part_num, error in failed_parts:
error_msg += f' Part {part_num}: {error}\n'
raise AirUnexpectedResponse(error_msg)
# Sort by part number to ensure correct order
uploaded_parts.sort(key=lambda x: x['part_number'])
return uploaded_parts
def complete_multipart_upload(
*,
api_client: AirAPI,
base_url: str,
image: Image,
uploaded_parts: list[dict[str, Any]],
) -> None:
"""Complete a multipart upload by sending ETags to backend.
Backend will use boto3.complete_multipart_upload() with the ETags.
Args:
api_client: The AirAPI client instance for making HTTP requests
base_url: Base URL for the images endpoint
image: Image instance
uploaded_parts: List of uploaded parts with part_number and etag
Raises:
AirUnexpectedResponse: If the backend rejects the completion request
requests.RequestException: If the completion request itself fails
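Example:
    A sketch, assuming uploaded is the list returned by
    upload_parts_to_s3 for this image:

        complete_multipart_upload(
            api_client=api,
            base_url=images_url,
            image=image,
            uploaded_parts=uploaded,  # [{'part_number': 1, 'etag': '...'}, ...]
        )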
"""
complete_payload = {
'parts': uploaded_parts,
}
complete_upload_url = join_urls(base_url, str(image.id), 'complete-upload')
complete_upload_response = api_client.client.patch(
complete_upload_url, data=mixins.serialize_payload(complete_payload)
)
raise_if_invalid_response(complete_upload_response, status_code=HTTPStatus.OK)
image.refresh()
def calculate_parts_info(
file_size: int,
num_parts: int,
chunk_size: int,
) -> list[dict[str, int]]:
"""Calculate byte ranges for each part based on file size and number of parts.
The API determines the number of parts and chunk size automatically.
This function calculates the actual byte ranges for each part.
Args:
file_size: Total file size in bytes
num_parts: Number of parts (from API response)
chunk_size: Size of each chunk in bytes (from API response)
Returns:
List of dicts with part_number, start, and size for each part
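Example:
    A worked, runnable example with illustrative sizes: a 250 MiB file
    that the backend split into three parts of at most 100 MiB (all
    sizes in bytes):

        >>> parts = calculate_parts_info(
        ...     file_size=262_144_000,   # 250 MiB
        ...     num_parts=3,
        ...     chunk_size=104_857_600,  # 100 MiB
        ... )
        >>> [(p['part_number'], p['start'], p['size']) for p in parts]
        [(1, 0, 104857600), (2, 104857600, 104857600), (3, 209715200, 52428800)]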
"""
if num_parts == 0:
return []
parts_info = []
for i in range(num_parts):
part_number = i + 1 # Parts are 1-indexed
start = i * chunk_size
is_last_part = i == num_parts - 1
if is_last_part:
# Last part gets whatever remains
size = file_size - start
else:
size = chunk_size
# Validate part size
if size <= 0:
raise AirUnexpectedResponse(
f'Part {part_number} has invalid size ({size} bytes). '
f'The backend returned too many parts ({num_parts}) for '
f'file size ({file_size} bytes).'
)
# S3 requires all parts except the last to be at least 5 MiB
if not is_last_part and size < MULTIPART_MIN_PART_SIZE:
raise AirUnexpectedResponse(
f'Part {part_number} size ({size} bytes) is below S3 minimum '
f'of {MULTIPART_MIN_PART_SIZE} bytes (5 MiB). The backend returned '
f'too many parts ({num_parts}) for file size ({file_size} bytes).'
)
parts_info.append(
{
'part_number': part_number,
'start': start,
'size': size,
}
)
return parts_info
def upload_image(
*,
api_client: AirAPI,
base_url: str,
image: Image,
filepath: str | Path,
timeout: timedelta | None = None,
max_workers: int = 1,
**kwargs: Any,
) -> Image:
"""Upload an image file using multipart upload.
All uploads use multipart upload to S3. The API calculates parts (~100MB each)
automatically based on file size.
The upload flow is:
1. Start upload with hash and size → get upload_id and presigned URLs
2. Upload each part directly to S3 using presigned URLs
3. Complete upload with part ETags
If any step fails, the multipart upload is automatically aborted to prevent
orphaned data in S3 storage.
Args:
api_client: The AirAPI client instance for making HTTP requests
base_url: Base URL for the images endpoint
image: Image instance
filepath: Path to the file to upload
timeout: Timeout per part upload
(default: DEFAULT_UPLOAD_TIMEOUT = 5 minutes).
This timeout applies to EACH part, not the entire multipart
operation.
max_workers: Number of concurrent upload workers. Default: 1 (sequential).
Set > 1 for parallel uploads (e.g., 4 for 4 concurrent uploads).
**kwargs: Additional arguments (currently unused, kept for API compatibility)
Note:
Presigned URL expiration varies by file size. Check the backend
documentation for the exact expiration time.
Returns:
Updated Image instance
Raises:
AirUnexpectedResponse: If upload fails or backend returns invalid data
Exception: For other upload errors
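Example:
    A sketch of a typical call; api, images_url, and image are assumed
    to be an authenticated AirAPI client, the images endpoint URL, and
    an Image record awaiting upload:

        from datetime import timedelta

        image = upload_image(
            api_client=api,
            base_url=images_url,
            image=image,
            filepath='/tmp/cumulus-vx.qcow2',
            timeout=timedelta(minutes=10),  # applies to each part
            max_workers=4,  # four concurrent part uploads
        )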
"""
file_size = os.path.getsize(filepath)
# Use the provided timeout (or default) for each part upload
# This ensures sufficient time for large parts without dividing by part count
timeout = timeout or DEFAULT_UPLOAD_TIMEOUT
timeout_per_part = timeout.total_seconds()
# Step 1: Get the file hash and initiate multipart upload
file_hash = sha256_file(filepath)
payload = {
'hash': file_hash,
'size': file_size,
}
start_upload_url = join_urls(base_url, str(image.id), 'start-upload')
start_upload_response = api_client.client.patch(
start_upload_url, data=mixins.serialize_payload(payload)
)
raise_if_invalid_response(start_upload_response, status_code=HTTPStatus.OK)
image.refresh()
# Get presigned URLs and chunk size from backend
response_data = start_upload_response.json()
upload_id = response_data.get('upload_id')
part_urls = response_data.get('part_urls', [])
chunk_size = response_data.get('chunk_size', MULTIPART_CHUNK_SIZE)
# Validate backend response before proceeding
if not upload_id:
raise AirUnexpectedResponse(
'Backend did not return upload_id for multipart upload. '
'This indicates a server error.'
)
# Steps 2 and 3: Upload parts, then complete, with cleanup on failure
# Wrap everything after upload_id is obtained to ensure cleanup on any failure
try:
if not part_urls:
raise AirUnexpectedResponse(
'Backend did not return presigned URLs for multipart upload. '
'This indicates a server error.'
)
# Calculate part byte ranges based on number of parts and chunk size from API
parts_info = calculate_parts_info(file_size, len(part_urls), chunk_size)
# Step 2: Upload each part directly to S3 using presigned URLs
uploaded_parts = upload_parts_to_s3(
api_client=api_client,
filepath=filepath,
parts_info=parts_info,
part_urls=part_urls,
timeout_per_part=timeout_per_part,
max_workers=max_workers,
)
# Step 3: Complete the multipart upload
complete_multipart_upload(
api_client=api_client,
base_url=base_url,
image=image,
uploaded_parts=uploaded_parts,
)
return image
except Exception:
# Cleanup: abort the multipart upload to prevent orphaned S3 data
abort_multipart_upload(api_client=api_client, base_url=base_url, image=image)
# Re-raise the original exception
raise