# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: MIT
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from dataclasses import Field, InitVar, asdict, dataclass, fields, is_dataclass
from datetime import datetime
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generic,
Iterator,
Literal,
Mapping,
Optional,
Tuple,
Type,
TypeAlias,
TypeVar,
cast,
get_origin,
get_type_hints,
)
from uuid import UUID
from air_sdk.exceptions import AirError, AirModelAttributeError
from air_sdk.types import (
get_list_arg,
get_optional_arg,
is_optional_union,
is_typed_dict,
is_union,
type_check,
)
from air_sdk.utils import (
as_field,
is_dunder,
iso_string_to_datetime,
join_urls,
to_uuid,
)
if TYPE_CHECKING:
from air_sdk import AirApi
T = TypeVar('T')
TAirModel = TypeVar('TAirModel', bound='AirModel')
TAirModel_co = TypeVar('TAirModel_co', bound='AirModel', covariant=True)
TParentAirModel_co = TypeVar('TParentAirModel_co', bound='AirModel', covariant=True)
TSupportedPrimitive = TypeVar('TSupportedPrimitive', int, str, float, bool)
PrimaryKey: TypeAlias = str | UUID
SpecialField = Mapping[str, object]
DataDict = dict[str, Any]
[docs]
def _generate_special_field() -> SpecialField:
"""Get a unique mapping for assignment of `metadata` on special `BaseModel` fields."""
return {'property': object()}
[docs]
@dataclass(eq=False)
class BaseModel:
__serializing__ = False # A flag indicating if the instance is serializing
[docs]
def dict(self) -> DataDict:
try:
self.__serializing__ = True
result = asdict(self)
finally:
self.__serializing__ = False
return result
[docs]
def json(self) -> str:
from air_sdk.endpoints.mixins import serialize_payload
return serialize_payload(self.dict())
[docs]
@dataclass(eq=False)
class AirModel(BaseModel, ABC):
_api: InitVar # type: ignore[type-arg]
FIELD_FOREIGN_KEY: ClassVar[SpecialField] = _generate_special_field()
FIELD_LAZY: ClassVar[SpecialField] = _generate_special_field()
[docs]
def __post_init__(self, _api: 'AirApi') -> None:
self.__api__ = _api
@property
def primary_key_field(self) -> str:
"""Returns the name of the field representing the primary key."""
return 'id'
@property
def __pk__(self) -> PrimaryKey:
"""
Returns current model's primary key for API-related actions.
"""
pk: PrimaryKey = getattr(self, self.primary_key_field)
return pk
@property
def detail_url(self) -> str:
return join_urls(self.get_model_api()(self.__api__).url, str(self.__pk__))
[docs]
def clear_cached_property(self, property_name: str) -> None:
"""Clear a cached property to force re-computation on next access.
Args:
property_name: Name of the cached property to clear
Example:
>>> simulation.clear_cached_property('ztp_script')
>>> # Next access will re-query the API
>>> print(simulation.ztp_script)
"""
try:
delattr(self, property_name)
except AttributeError:
pass # Property not cached yet
[docs]
def __eq__(self, other: Any) -> bool:
if pk := getattr(self, '__pk__', None):
return bool(str(pk) == str(getattr(other, '__pk__', None)))
return bool(self is other)
[docs]
def __getitem__(self, key: str) -> Any:
"""Enable dictionary-style access to fields: obj['field_name']"""
try:
return getattr(self, key)
except AttributeError:
raise KeyError(key) from None
[docs]
def __setitem__(self, key: str, value: Any) -> None:
"""Enable dictionary-style assignment to fields: obj['field_name'] = value"""
setattr(self, key, value)
[docs]
def __contains__(self, key: str) -> bool:
"""Enable membership testing: 'field_name' in obj"""
return hasattr(self, key)
[docs]
def __getattribute__(self, name: str) -> Any:
value = super().__getattribute__(name)
# filter out dunder attributes (avoids recursion on upcoming `as_field` call)
if value and not is_dunder(name):
field = as_field(self, name)
if field is not None:
if self.__serializing__ and field.metadata == AirModel.FIELD_FOREIGN_KEY:
# Use the pk instead of recursively serializing FK fields.
if isinstance(value, list):
return [str(fk.__pk__) for fk in value]
return str(value.__pk__)
if field.metadata == AirModel.FIELD_LAZY and value == AirModel.FIELD_LAZY:
value = getattr( # Resolve FK fields when their attrs are accessed.
self.get_model_api()(self.__api__).get(self.__pk__), name
)
setattr(self, name, value)
return value
[docs]
def __refresh__(self, refreshed_obj: Optional[BaseModel] = None) -> None:
"""Refreshed the instances data from the backend.
Raises
------
NotImplementedError - When the model's API does not implement `get`.
"""
if refreshed_obj is None:
endpoint_api = self.get_model_api()(self.__api__)
if endpoint_api is None:
raise NotImplementedError
refreshed_obj = endpoint_api.get(pk=self.__pk__)
for field in fields(self):
setattr(self, field.name, getattr(refreshed_obj, field.name))
[docs]
@classmethod
@abstractmethod
def get_model_api(
cls: Type[TAirModel_co],
) -> Type[BaseEndpointAPI[TAirModel_co]]:
"""
Returns the respective `AirModelAPI` type for this model.
"""
[docs]
def refresh(self) -> None:
"""Refresh the instance by querying new API data.
This uses the `get` method on the model's EndpointAPI by default.
"""
self.__refresh__()
[docs]
def _ensure_pk_exists(self, context: str) -> None:
# Cannot perform detailed API calls for instances without populated primary keys
if self.__pk__ is None:
raise AirError(
f'The {self.__class__.__name__} cannot be {context}: '
'primary key is `None`.'
)
[docs]
def update(self, *args: Any, **kwargs: Any) -> None:
self._ensure_pk_exists('updated')
updated_obj = self.get_model_api()(self.__api__).patch(self.__pk__, **kwargs)
self.__refresh__(
updated_obj
) # Ensure update data is reflected in model instance.
[docs]
def full_update(self, *args: Any, **kwargs: Any) -> None:
self._ensure_pk_exists('fully updated')
updated_obj = self.get_model_api()(self.__api__).put(self.__pk__, **kwargs)
self.__refresh__(
updated_obj
) # Ensure update data is reflected in model instance.
[docs]
def delete(self) -> None:
"""Delete the instance and nullify the primary key."""
self._ensure_pk_exists('deleted')
self.get_model_api()(self.__api__).delete(self.__pk__)
setattr(self, self.primary_key_field, None)
[docs]
class ForeignKeyMixin(Generic[TAirModel_co]):
"""AirModel mixin for lazily resolving the instance."""
def __init__(self, primary_key: UUID, _api: 'AirApi') -> None:
self.__fk__ = primary_key
self.__fk_resolved__ = False
self.__api__ = _api
@property
def __pk__(self) -> PrimaryKey:
return self.__fk__
[docs]
def __getattribute__(self, name: str) -> Any:
"""Loads the instance upon initial access to an exposed attribute."""
if (
not is_dunder(name)
and as_field(self, name) is not None
and not self.__fk_resolved__
):
self.__refresh__()
self.__fk_resolved__ = True
return super().__getattribute__(name)
[docs]
class ApiNotImplementedMixin:
"""Mixin used to allow AirModel subclasses to have an unimplemented API."""
[docs]
def __refresh__(self, refreshed_obj: Optional[BaseModel] = None) -> None:
if refreshed_obj is None and getattr(self, '__fk_resolved__', True) is True:
raise NotImplementedError
super().__refresh__(refreshed_obj) # type: ignore[misc]
[docs]
class EndpointMethodMixin:
"""Mixin class for defining common endpoint methods.
This is used to prevent with the intention of raising
a `NotImplementedError` instead of an `AttributeError` when
specific endpoint methods are not implemented in the SDK or API.
"""
[docs]
def list(self, **kwargs: Any) -> Iterator[AirModel]:
raise NotImplementedError
[docs]
def create(self, *args: Any, **kwargs: Any) -> AirModel:
raise NotImplementedError
[docs]
def get(self, *args: Any, **kwargs: Any) -> AirModel:
raise NotImplementedError
[docs]
def put(self, *args: Any, **kwargs: Any) -> AirModel:
raise NotImplementedError
[docs]
def patch(self, *args: Any, **kwargs: Any) -> AirModel:
raise NotImplementedError
[docs]
def delete(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError
[docs]
class BaseEndpointAPI(EndpointMethodMixin, Generic[TAirModel_co]):
model: Type[TAirModel_co]
API_PATH: str = ''
def __init__(self, api: 'AirApi', default_filters: Optional[dict[str, Any]] = None):
self.__api__ = api
self.default_filters = default_filters or {}
@cached_property
def url(self) -> str:
return join_urls(self.__api__.client.api_url, self.API_PATH)
@cached_property
def open_api_url(self) -> str:
return join_urls(self.__api__.client.api_url, '#')
@cached_property
def model_cls_type_hints(self) -> DataDict:
return get_type_hints(self.model)
@cached_property
def model_cls_fields(self) -> Tuple[Field[Any], ...]:
return fields(self.model)
[docs]
def load_model(self, data: DataDict) -> TAirModel_co:
"""Construct a new model instance, validate data, and set the API Client."""
model_field_names = {field.name for field in self.model_cls_fields}
provided_fields: list[Field[Any]] = []
missing_fields: list[Field[Any]] = []
for field in self.model_cls_fields:
if field.name in data:
provided_fields.append(field)
else:
missing_fields.append(field)
try:
model_inst = self.model(
_api=self.__api__,
**self.parse_provided_fields(provided_fields, data),
**self.get_defaults_for_missing_fields(missing_fields),
)
# Set extra API fields that are not in the SDK model as dynamic attributes
# This allows access to new API fields before SDK is updated
extra_fields = set(data) - model_field_names
for field_name in extra_fields:
# Skip dunder attributes and other potentially dangerous names
# Prevents API from overwriting SDK methods like .json() or .dict
if field_name.startswith('_') or field_name in ('dict', 'json'):
warnings.warn(
f"API returned field '{field_name}' which is reserved by the "
f'SDK. Skipping to avoid conflicts. Consider updating your SDK.',
UserWarning,
stacklevel=4,
)
continue
try:
setattr(model_inst, field_name, data[field_name])
except AttributeError as e:
warnings.warn(
f"API returned extra field '{field_name}' but the SDK failed to "
f'assign it: {e}. Consider updating your SDK.',
UserWarning,
stacklevel=4,
)
return model_inst
except TypeError as e:
raise AirModelAttributeError(
f'failed to instantiate `{self.model}`: {e}'
) from None
[docs]
def get_defaults_for_missing_fields(self, dc_fields: list[Field[Any]]) -> DataDict:
"""Get default values for fields missing from API response."""
special_fields: dict[str, object | None] = {}
for field in dc_fields:
if field.metadata == AirModel.FIELD_LAZY:
# lazy fields which are not present are assigned a placeholder value
special_fields[field.name] = AirModel.FIELD_LAZY
elif is_optional_union(self.model_cls_type_hints[field.name]):
# optional fields which are not present are assigned to `None`
special_fields[field.name] = None
return special_fields
[docs]
def parse_provided_fields(
self, dc_fields: list[Field[Any]], data: DataDict
) -> DataDict:
return {
field.name: self.parse_field(
self.model_cls_type_hints[field.name],
field.metadata,
data[field.name],
f'field `{field.name}` of `{self.model.__name__}`',
)
for field in dc_fields
}
[docs]
def parse_field(
self,
hint: Type[T],
metadata: Mapping[Any, Any],
provided_value: Any,
context: str,
) -> T:
"""Parse the provided value based on the type hint of the value.
This allows us to perform type checking of provided values and assists in
the implementation of our `FIELD_FOREIGN_KEY` and `FIELD_LAZY` fields.
If parsing fails (e.g., due to API type changes), falls back to
returning the raw value to prevent crashes.
"""
origin = get_origin(hint)
try:
if origin is not None:
if isinstance(origin, type) and issubclass(origin, list):
return cast(
T,
self.handle_list_field(
cast(Type[list[T]], hint), metadata, provided_value
),
)
elif (
isinstance(origin, type)
and issubclass(origin, dict)
and isinstance(provided_value, dict)
):
return cast(T, provided_value)
elif is_optional_union(hint): # field is optional (e.g., str | None)
return cast(
T,
self.handle_optional_field(
cast(Type[Optional[T]], hint), metadata, provided_value
),
)
elif is_union(hint): # non-optional union (e.g., int | float)
if type_check(provided_value, hint):
return cast(T, provided_value)
raise AirModelAttributeError(
f'value {provided_value!r} does not match union type {hint}'
)
elif isinstance(hint, type): # field is an AirModel
if issubclass(hint, AirModel):
return self.handle_air_model_field(hint, metadata, provided_value) # type: ignore
elif is_dataclass(hint) and isinstance(provided_value, dict):
return cast(hint, hint(**provided_value)) # type: ignore[valid-type]
elif issubclass(hint, datetime): # field is a datetime object
return cast(hint, self.handle_datetime_field(provided_value)) # type: ignore
elif issubclass(hint, (int, str, bool, float)): # field is a primitive
return self.handle_primitive_field(hint, provided_value) # type: ignore
# Handle TypedDict types - pass through dict values
if is_typed_dict(hint) and isinstance(provided_value, dict):
return cast(T, provided_value)
# Handle Literal and Union types using type_check
if (origin is Literal or is_union(hint)) and type_check(provided_value, hint):
return provided_value # type: ignore
# No parser matched - raise to trigger fallback
raise AirModelAttributeError('No parser matched for field type')
except Exception as e:
# FALLBACK: Catch all parsing errors (type mismatches, validation, etc.)
# Return raw value to prevent SDK crashes when API changes
warnings.warn(
f'{context}: {e}. Using raw value from API. Consider updating your SDK.',
UserWarning,
stacklevel=4,
)
return provided_value # type: ignore
[docs]
def handle_list_field(
self,
hint: Type[list[T]],
metadata: Mapping[Any, Any],
provided_value: Any | list[Any],
) -> list[T]:
"""
Provided `data` argument is validated to be an actual list.
Each item in `data` list is then validated against the target type
and parsed individually.
"""
if not isinstance(provided_value, list):
raise AirModelAttributeError(
f'field data is of type `{type(provided_value).__name__}`, '
f'expected `{list}`: {provided_value}'
)
return [
self.parse_field(
get_list_arg(hint), metadata, data_item, f'item at index `{index}`'
)
for index, data_item in enumerate(provided_value)
]
[docs]
def handle_optional_field(
self,
hint: Type[Optional[T]],
metadata: Mapping[Any, Any],
provided_value: Any,
) -> Optional[T]:
if provided_value is None:
return None
return self.parse_field(
get_optional_arg(hint), metadata, provided_value, 'optional field'
)
[docs]
def handle_air_model_field(
self,
hint: Type[TAirModel_co],
metadata: Mapping[Any, Any],
provided_value: Any,
) -> TAirModel_co:
"""
`AirModel` fields are validated as follows:
- Foreign key fields are reserved for on-demand loading
- Otherwise, field is parsed as a regular `BaseModel`
"""
if metadata == hint.FIELD_FOREIGN_KEY:
if isinstance(provided_value, str) or isinstance(provided_value, UUID):
if primary_key := to_uuid(str(provided_value)):
lazy_fk = cast(
hint, # type: ignore
type(str(hint.__name__), (ForeignKeyMixin, hint), {})(
primary_key, _api=self.__api__
),
)
return lazy_fk
raise AirModelAttributeError(
f'`{hint.__name__}` can not be parsed from foreign key due to '
f'invalid UUID value: {provided_value}'
)
elif isinstance(provided_value, hint):
return provided_value
if isinstance(provided_value, dict):
fk_api = hint.get_model_api()(self.__api__)
return fk_api.load_model(provided_value)
raise AirModelAttributeError(
f'`{hint.__name__}` can not be parsed from foreign key due to '
f'invalid value: {provided_value}'
)
[docs]
def handle_datetime_field(self, provided_value: Any) -> datetime:
if isinstance(provided_value, datetime):
return provided_value
elif isinstance(provided_value, str):
value = iso_string_to_datetime(provided_value)
if value is None:
raise AirModelAttributeError(
f'field data is not a valid ISO string: {provided_value}'
)
return value
raise AirModelAttributeError(
f'`{datetime}` field can not be parsed from field data of type '
f'`{type(provided_value).__name__}`: {provided_value}'
)
[docs]
def handle_primitive_field(
self,
hint: Type[TSupportedPrimitive],
provided_value: Any,
) -> TSupportedPrimitive:
"""Primitive fields are validated for type mismatch and returned as-is."""
if not isinstance(provided_value, hint):
raise AirModelAttributeError(
f'field data is of type `{type(provided_value).__name__}`, '
f'expected `{hint.__name__}`: {provided_value}'
)
return provided_value