# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
from __future__ import annotations
import hashlib
import inspect
import re
import time
import warnings
from dataclasses import _MISSING_TYPE, Field, fields, is_dataclass
from datetime import datetime, timedelta, timezone
from functools import wraps
from http import HTTPStatus
from json import JSONDecodeError
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Optional,
TypeVar,
cast,
get_type_hints,
)
from urllib.parse import ParseResult, urlparse
from uuid import UUID, uuid4

from requests import Response

from air_sdk.exceptions import AirUnexpectedResponse
from air_sdk.types import type_check

if TYPE_CHECKING:
from air_sdk.air_model import AirModel
F = TypeVar('F', bound=Callable[..., Any])
def filter_missing(**kwargs: Any) -> dict[str, Any]:
"""Filter out MISSING values from kwargs.
This is a helper function to remove dataclasses.MISSING sentinel values
before passing kwargs to API methods.
Args:
**kwargs: Keyword arguments that may contain MISSING values
Returns:
Dictionary with MISSING values filtered out
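
    Example (illustrative):

        >>> from dataclasses import MISSING
        >>> filter_missing(name='leaf01', memory=MISSING)
        {'name': 'leaf01'}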
"""
return {k: v for k, v in kwargs.items() if not isinstance(v, _MISSING_TYPE)}
# Mapping of string type names to actual builtin types
BUILTIN_TYPES = {
'str': str,
'int': int,
'bool': bool,
'float': float,
'dict': dict,
'list': list,
'tuple': tuple,
'set': set,
}
API_URI_PATTERN = r'^/api/v\d+.*$'
COMPILED_API_URI_PATTERN = re.compile(API_URI_PATTERN)
def join_urls(*args: str) -> str:
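    """Join URL fragments with single slashes and append a trailing slash.

    Note that a leading slash on the first fragment is not preserved.

    Example (illustrative):

        >>> join_urls('https://air.nvidia.com/api/v3', 'simulations')
        'https://air.nvidia.com/api/v3/simulations/'
    """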
return '/'.join(frag.strip('/') for frag in args) + '/'
def iso_string_to_datetime(iso: str) -> Optional[datetime]:
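    """Parse an ISO 8601 string, returning None if it is not valid.

    A trailing 'Z' is treated as UTC.

    Example (illustrative):

        >>> iso_string_to_datetime('2025-01-01T00:00:00Z')
        datetime.datetime(2025, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
    """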
try:
return datetime.fromisoformat(iso.replace('Z', '+00:00'))
except ValueError:
return None
def datetime_to_iso_string(date: datetime) -> str:
"""Convert datetime to ISO string in UTC.
Accepts any timezone-aware datetime. For naive datetimes (no timezone),
assumes local timezone and emits a warning.
Args:
date: The datetime to convert
Returns:
ISO 8601 formatted string in UTC (with 'Z' suffix)
Warns:
UserWarning: If datetime is naive (no timezone specified)
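
    Example (illustrative):

        >>> datetime_to_iso_string(datetime(2025, 1, 1, tzinfo=timezone.utc))
        '2025-01-01T00:00:00Z'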
"""
    # Handle naive datetimes by assuming local timezone and warning
if date.tzinfo is None:
warnings.warn(
'Naive datetime provided. '
'Assuming local timezone. '
'Use datetime.now(timezone.utc) for explicit UTC times.',
UserWarning,
stacklevel=2,
)
# Assume local timezone
date = date.astimezone()
return date.astimezone(tz=timezone.utc).isoformat().replace('+00:00', 'Z')
def to_uuid(uuid: str) -> Optional[UUID]:
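    """Parse a string into a UUID, returning None if it is not valid.

    Note: because 'version=4' is passed, UUID overrides the version and
    variant bits, so strings of other UUID versions are coerced to v4
    rather than rejected.
    """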
try:
return UUID(uuid, version=4)
except ValueError:
return None
def to_url(url: str) -> Optional[ParseResult]:
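    """Parse a URL string, returning None unless it has a scheme, network
    location, and path.

    Note that a bare host URL with no path (e.g. 'https://air.nvidia.com')
    fails the path check and returns None.
    """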
try:
parsed_url = urlparse(url)
return (
parsed_url
if all((parsed_url.scheme, parsed_url.netloc, parsed_url.path))
else None
)
except AttributeError:
return None
def is_dunder(name: str) -> bool:
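    """Return True for dunder names such as '__init__'."""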
delimiter = '__'
return name.startswith(delimiter) and name.endswith(delimiter)
def as_field(class_or_instance: object, name: str) -> Optional[Field]: # type: ignore[type-arg]
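    """Return the dataclass Field named 'name', or None if 'class_or_instance'
    is not a dataclass or has no field by that name.
    """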
if is_dataclass(class_or_instance):
try:
return next(
field for field in fields(class_or_instance) if field.name == name
)
except StopIteration:
pass
return None
def _resolve_type_hints_fallback(func: Callable[..., Any]) -> dict[str, Any]:
"""
Fallback type hint resolution when get_type_hints() fails.
This handles cases where TYPE_CHECKING imports aren't available at runtime.
Only validates types that can be resolved at runtime.
Args:
func: Function to extract type hints from
Returns:
Dictionary of resolvable type hints (name -> type)
"""
hints: dict[str, Any] = {}
raw_annotations = getattr(func, '__annotations__', {})
for name, annotation in raw_annotations.items():
# Skip return type
if name == 'return':
continue
# Try to resolve each annotation individually
try:
if isinstance(annotation, str):
# Try basic type resolution from builtins
if annotation in BUILTIN_TYPES:
hints[name] = BUILTIN_TYPES[annotation]
elif annotation == 'Any':
hints[name] = Any
# Otherwise, skip (likely a TYPE_CHECKING import)
elif hasattr(annotation, '__origin__'):
# Generic type like dict[str, Any]
hints[name] = annotation
elif inspect.isclass(annotation):
# Real class
hints[name] = annotation
except Exception:
# Can't resolve - skip validation for this parameter
continue
return hints
def validate_payload_types(func: F) -> F:
"""A wrapper for validating the type of payload during create."""
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
# Try to get type hints with proper resolution
hints = get_type_hints(func)
except NameError:
# TYPE_CHECKING imports aren't available at runtime
# Use fallback to resolve only available types
hints = _resolve_type_hints_fallback(func)
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
for name, value in bound_args.arguments.items():
if name in hints:
expected_type = hints[name]
if not type_check(value, expected_type):
raise TypeError(
f"Argument '{name}' must be {expected_type}, got {type(value)}"
)
return func(*args, **kwargs)
return cast(F, wrapper)
def sha256_file(path: str | Path) -> str:
"""Get the SHA256 hash of the local file."""
h = hashlib.sha256()
with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b''):  # read in 4 KiB chunks
h.update(chunk)
return h.hexdigest()
def calculate_multipart_info(file_size: int, chunk_size: int) -> list[dict[str, int]]:
"""Calculate part information for multipart upload.
Args:
file_size: Total size of the file in bytes
chunk_size: Size of each chunk in bytes
Returns:
List of dictionaries containing part info with keys:
- part_number: 1-based part number
- start: Start byte offset
- size: Size of this part in bytes
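
    Example (illustrative; a 10 MiB file split into 4 MiB chunks):

        >>> calculate_multipart_info(10 * 1024 * 1024, 4 * 1024 * 1024)
        [{'part_number': 1, 'start': 0, 'size': 4194304},
         {'part_number': 2, 'start': 4194304, 'size': 4194304},
         {'part_number': 3, 'start': 8388608, 'size': 2097152}]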
"""
parts = []
part_number = 1
offset = 0
while offset < file_size:
part_size = min(chunk_size, file_size - offset)
parts.append({'part_number': part_number, 'start': offset, 'size': part_size})
offset += part_size
part_number += 1
return parts
class FilePartReader:
"""File-like object that reads only a specific portion of a file.
Used for streaming multipart uploads without loading entire file into memory.
This class implements the context manager protocol and provides a read()
method compatible with requests streaming uploads.
Args:
file_path: Path to the file to read from
start: Starting byte offset in the file
size: Number of bytes to read from the start offset
Example:
>>> with FilePartReader('large_file.bin', start=0, size=5242880) as reader:
... requests.put(presigned_url, data=reader)
"""
    def __init__(self, file_path: str | Path, start: int, size: int) -> None:
self.file_path = file_path
self.start = start
self.size = size
self.remaining = size
self.f: BinaryIO | None = None
def __enter__(self) -> 'FilePartReader':
self.f = open(self.file_path, 'rb')
self.f.seek(self.start)
return self
def __exit__(self, *args: Any) -> None:
if self.f:
self.f.close()
def __len__(self) -> int:
"""Return the total size of this part.
This is required for requests to set Content-Length header
instead of using Transfer-Encoding: chunked (which S3 doesn't support).
"""
return self.size
def read(self, chunk_size: int = -1) -> bytes:
"""Read up to chunk_size bytes, but never exceed the part size.
Args:
chunk_size: Number of bytes to read. If -1, reads remaining bytes.
Returns:
Bytes read from the file, up to the specified chunk size
"""
if self.remaining <= 0:
return b''
if not self.f:
raise RuntimeError(
'FilePartReader must be used within a context manager '
'(with FilePartReader(...) as reader:)'
)
if chunk_size < 0:
chunk_size = self.remaining
else:
chunk_size = min(chunk_size, self.remaining)
data = self.f.read(chunk_size)
self.remaining -= len(data)
return data
def create_short_uuid() -> str:
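    """Return the first 18 hex characters of a random UUID4, hyphens removed."""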
return str(uuid4()).replace('-', '')[:18]
def normalize_api_url(url: str) -> str:
"""Ensures the API URL ends with the correct path."""
parsed_url = urlparse(url)
if not COMPILED_API_URI_PATTERN.match(parsed_url.path):
parsed_url = parsed_url._replace(path='/api/v3')
return parsed_url.geturl()
def raise_if_invalid_response(
res: Response, status_code: HTTPStatus = HTTPStatus.OK, data_type: type | None = dict
) -> None:
"""
Validates that a given API response has the expected status code and JSON payload
Arguments:
res (requests.HTTPResponse) - API response object
status_code [int] - Expected status code (default: 200)
Raises:
AirUnexpectedResponse - Raised if an unexpected response is received from the API
"""
json = None
if res.status_code != status_code:
raise AirUnexpectedResponse(message=res.text, status_code=res.status_code)
if not data_type:
return
try:
json = res.json()
    except JSONDecodeError as err:
        raise AirUnexpectedResponse(
            message=res.text, status_code=res.status_code
        ) from err
if not isinstance(json, data_type):
raise AirUnexpectedResponse(
message=f'Expected API response to be of type {data_type}, '
+ f'got {type(json)}',
status_code=res.status_code,
)
def wait_for_state(
model: AirModel,
target_states: str | list[str],
*,
state_field: str = 'state',
timeout: timedelta | None = None,
poll_interval: timedelta | None = None,
error_states: str | list[str] | None = None,
) -> None:
"""Wait for a model to reach one of the target states.
This is a generic utility that works with any AirModel that has a state field
(e.g., Simulation.state, Node.state, Image.upload_status, etc.).
Args:
model: The AirModel instance to monitor (Simulation, Node, Image, etc.)
target_states: Single state or list of states to wait for
state_field: Name of the field containing the state (default: 'state').
Use 'upload_status' for Images, 'state' for most other models.
timeout: Maximum time to wait (default: 120 seconds)
poll_interval: Time between status checks (default: 2 seconds)
error_states: Single state or list of states that should raise an error.
If None, no error states are checked.
Raises:
ValueError: If the model enters one of the error states
TimeoutError: If timeout is reached before target state
AttributeError: If the model doesn't have the specified state_field
Example:
>>> # Wait for simulation to become active
>>> wait_for_state(simulation, 'ACTIVE', error_states=['INVALID', 'DELETING'])
>>>
>>> # Wait for image upload to complete
>>> wait_for_state(image, 'COMPLETE', state_field='upload_status')
>>>
>>> # Wait for node to boot or become active
>>> wait_for_state(node, ['BOOTING', 'ACTIVE'], error_states='ERROR')
"""
# Set defaults
if timeout is None:
timeout = timedelta(seconds=120)
if poll_interval is None:
poll_interval = timedelta(seconds=2)
# Normalize to lists
if isinstance(target_states, str):
target_states = [target_states]
if isinstance(error_states, str):
error_states = [error_states]
elif error_states is None:
error_states = []
# Validate that the model has the specified state field
if not hasattr(model, state_field):
raise AttributeError(
f'Model {type(model).__name__} does not have a "{state_field}" field. '
f'Available fields: {", ".join(f.name for f in fields(model))}'
)
start_time = time.time()
timeout_seconds = timeout.total_seconds()
poll_interval_seconds = poll_interval.total_seconds()
end_time = start_time + timeout_seconds
while time.time() < end_time:
model.refresh()
current_state = getattr(model, state_field)
if current_state in target_states:
return
if current_state in error_states:
model_type = type(model).__name__
raise ValueError(
f'{model_type} entered error state: {current_state}. '
f'Please check the {model_type.lower()} for more details.'
)
# Wait before polling again
time.sleep(poll_interval_seconds)
# Timeout reached
current_state = getattr(model, state_field)
states_str = ', '.join(target_states)
model_type = type(model).__name__
raise TimeoutError(
f'Timed out waiting for {model_type} to reach state(s): {states_str}. '
f'Current {state_field}: {current_state}'
)