Source code for air_sdk.endpoints.mixins

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
from __future__ import annotations

import json
import warnings
from http import HTTPStatus
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    TypedDict,
)

from air_sdk.air_json_encoder import AirJSONEncoder
from air_sdk.air_model import DataDict, PrimaryKey, TAirModel_co
from air_sdk.exceptions import AirModelAttributeError
from air_sdk.utils import filter_missing, join_urls, raise_if_invalid_response

if TYPE_CHECKING:
    from air_sdk import AirApi
    from air_sdk.air_model import AirModel


def serialize_payload(data: Dict[str, Any] | List[Dict[str, Any]]) -> str:
    """Encode ``data`` as compact JSON text using ``AirJSONEncoder``.

    The SDK's encoder is passed as ``cls`` so that values the stock JSON
    encoder cannot handle are converted by it (presumably UUIDs and AirModel
    instances — confirm against ``AirJSONEncoder``). Output has no
    indentation and no whitespace after separators.

    Args:
        data: A dictionary, or a list of dictionaries, to serialize.

    Returns:
        The serialized JSON string.
    """
    compact_separators = (',', ':')
    return json.dumps(
        data,
        cls=AirJSONEncoder,
        separators=compact_separators,
        indent=None,
    )
def build_resource_url(
    base_url: str, resource: AirModel | PrimaryKey, *paths: str
) -> str:
    """Build URL for resource-related endpoints.

    When ``resource`` is an AirModel its ``id`` attribute is used; otherwise
    ``resource`` is treated as the primary key itself. The id is converted
    with ``str()`` and joined with ``base_url`` and any extra path segments
    via ``join_urls``.

    Args:
        base_url: The base URL for the endpoint
        resource: An AirModel instance or a PrimaryKey (str/UUID)
        *paths: Additional path segments to append

    Returns:
        The constructed URL string

    Example:
        >>> build_resource_url('/api/simulations/', simulation, 'start')
        '/api/simulations/abc-123/start/'
        >>> build_resource_url('/api/nodes/', 'node-id-456', 'interfaces')
        '/api/nodes/node-id-456/interfaces/'
    """
    # Imported at call time; the module imports AirModel only under
    # TYPE_CHECKING, presumably to avoid a circular import.
    from air_sdk.air_model import AirModel

    if isinstance(resource, AirModel):
        resource_id = resource.id
    else:
        resource_id = resource
    return join_urls(base_url, str(resource_id), *paths)
class BaseApiMixin:
    """Typing base shared by the endpoint API mixins.

    Only annotations are declared here (no values are assigned), so the
    concrete endpoint classes that combine these mixins must supply the
    attributes below.
    """

    # AirApi instance whose ``client`` is used to issue HTTP requests.
    __api__: AirApi
    # Collection URL of the endpoint; detail URLs are built from it.
    url: str
    # Turns one raw response object (a DataDict) into a model instance.
    load_model: Callable[[DataDict], TAirModel_co]
class PaginatedResponseData(TypedDict):
    """Shape of one page returned by a paginated list endpoint."""

    # Object count reported by the server (presumably across all pages).
    count: int
    # URL of the following page; pagination stops when this is falsy.
    next: Optional[str]
    # URL of the preceding page, if any.
    previous: Optional[str]
    # Raw objects on this page, each handed to ``load_model``.
    results: List[DataDict]
class ListApiMixin(BaseApiMixin, Generic[TAirModel_co]):
    """Returns an iterable of model objects.

    Handles pagination in the background.
    """

    def list(self, **params: Any) -> Iterator[TAirModel_co]:
        """Return an iterator of model instances.

        This is a generator: no request is sent until the iterator is first
        advanced, and each following page is requested only after all
        results of the previous page have been yielded.

        Args:
            **params: Query parameters for the first request. ``MISSING``
                sentinel values are dropped. If the endpoint defines a
                ``default_filters`` dict, its entries are used only for keys
                not given here (explicit params take precedence). ``limit``
                defaults to the client's ``pagination_page_size``.

        Yields:
            One model instance (built by ``load_model``) per entry in each
            page's ``results``.

        Raises:
            Whatever ``raise_if_invalid_response`` raises for a failed
            response.
        """
        # Filter out MISSING sentinel values
        params = filter_missing(**params)

        # Merge default filters with provided params (params take precedence).
        # The previous code called ``params.update(default_filters)``, which
        # let the defaults overwrite explicitly passed params.
        default_filters = getattr(self, 'default_filters', None)
        if isinstance(default_filters, dict):
            params = {**default_filters, **params}

        params.setdefault('limit', self.__api__.client.pagination_page_size)
        # Round-trip through the SDK encoder so UUIDs and AirModel params
        # become plain JSON-compatible values.
        params = json.loads(serialize_payload(params))

        # The first request uses ``self.url`` with ``params``; every later
        # request follows the server-provided ``next`` URL unchanged.
        url: Optional[str] = self.url
        next_url: Optional[str] = None
        while url or next_url:
            if isinstance(next_url, str):
                response = self.__api__.client.get(next_url)
            else:
                response = self.__api__.client.get(url, params=params)
            raise_if_invalid_response(response)
            paginated_response_data: PaginatedResponseData = response.json()
            url = None
            next_url = paginated_response_data['next']
            for obj_data in paginated_response_data['results']:
                yield self.load_model(obj_data)
class CreateApiMixin(BaseApiMixin, Generic[TAirModel_co]):
    """Adds a ``create`` method that POSTs a new object to the endpoint."""

    def create(self, *args: Any, **kwargs: Any) -> TAirModel_co:
        """Create a new object and return it as a model instance.

        Args:
            *args: Accepted but not used by this implementation (kept for
                signature compatibility, presumably with overriding subclasses).
            **kwargs: Fields of the new object. ``MISSING`` sentinel values
                are dropped. If the endpoint defines a ``default_filters``
                dict, its entries fill in keys not given here (explicit
                kwargs take precedence).

        Returns:
            The model built by ``load_model`` from the response JSON.

        Raises:
            Whatever ``raise_if_invalid_response`` raises when the response
            is not ``201 CREATED``.
        """
        # Filter out MISSING sentinel values
        kwargs = filter_missing(**kwargs)

        # Merge default filters with provided params (params take precedence).
        # The previous code called ``kwargs.update(default_filters)``, which
        # let the defaults overwrite explicitly passed values.
        default_filters = getattr(self, 'default_filters', None)
        if isinstance(default_filters, dict):
            kwargs = {**default_filters, **kwargs}

        response = self.__api__.client.post(self.url, data=serialize_payload(kwargs))
        raise_if_invalid_response(response, status_code=HTTPStatus.CREATED)
        return self.load_model(response.json())
class GetApiMixin(BaseApiMixin, Generic[TAirModel_co]):
    """Adds a ``get`` method that retrieves a single object by primary key."""

    def get(self, pk: PrimaryKey, **params: Any) -> TAirModel_co:
        """Fetch the object at ``<url>/<pk>`` and load it as a model.

        Args:
            pk: Primary key of the object; converted with ``str()`` when
                building the detail URL.
            **params: Sent unchanged as query parameters of the GET request.

        Returns:
            The model built by ``load_model`` from the response JSON.

        Raises:
            Whatever ``raise_if_invalid_response`` raises for a failed
            response.
        """
        response = self.__api__.client.get(
            join_urls(self.url, str(pk)),
            params=params,
        )
        raise_if_invalid_response(response)
        payload = response.json()
        return self.load_model(payload)
class PutApiMixin(BaseApiMixin, Generic[TAirModel_co]):
    """Adds a ``put`` method that sends a PUT to an object's detail URL."""

    def put(self, pk: PrimaryKey, **kwargs: Any) -> TAirModel_co:
        """PUT ``kwargs`` to ``<url>/<pk>`` and return the resulting model.

        If the response body cannot be loaded because required fields are
        missing (``AirModelAttributeError``), a warning is emitted and the
        object is re-fetched with ``self.get(pk)``.

        Args:
            pk: Primary key of the object; converted with ``str()``.
            **kwargs: Fields to send; ``MISSING`` sentinel values are dropped.

        Returns:
            The updated model instance.

        Raises:
            Whatever ``raise_if_invalid_response`` raises when the response
            is not ``200 OK``.
        """
        fields = filter_missing(**kwargs)
        detail_url = join_urls(self.url, str(pk))
        response = self.__api__.client.put(detail_url, data=serialize_payload(fields))
        raise_if_invalid_response(response, status_code=HTTPStatus.OK)
        try:
            return self.load_model(response.json())
        except AirModelAttributeError:
            # Partial response body: fall back to a full GET of the object.
            # stacklevel=2 attributes the warning to the caller of put().
            warnings.warn(
                f'PUT response missing required fields for {self.__class__.__name__} '
                f'with pk={pk}, falling back to GET request',
                stacklevel=2,
            )
            return self.get(pk)  # type: ignore[attr-defined,no-any-return]
class PatchApiMixin(BaseApiMixin, Generic[TAirModel_co]):
    """Adds a ``patch`` method that sends a PATCH to an object's detail URL."""

    def patch(self, pk: PrimaryKey, **kwargs: Any) -> TAirModel_co:
        """PATCH ``kwargs`` to ``<url>/<pk>`` and return the resulting model.

        If the response body cannot be loaded because required fields are
        missing (``AirModelAttributeError``), a warning is emitted and the
        object is re-fetched with ``self.get(pk)``.

        Args:
            pk: Primary key of the object; converted with ``str()``.
            **kwargs: Fields to send; ``MISSING`` sentinel values are dropped.

        Returns:
            The updated model instance.

        Raises:
            Whatever ``raise_if_invalid_response`` raises when the response
            is not ``200 OK``.
        """
        fields = filter_missing(**kwargs)
        detail_url = join_urls(self.url, str(pk))
        response = self.__api__.client.patch(detail_url, data=serialize_payload(fields))
        raise_if_invalid_response(response, status_code=HTTPStatus.OK)
        try:
            return self.load_model(response.json())
        except AirModelAttributeError:
            # Partial response body: fall back to a full GET of the object.
            # stacklevel=2 attributes the warning to the caller of patch().
            warnings.warn(
                f'PATCH response missing required fields for {self.__class__.__name__} '
                f'with pk={pk}, falling back to GET request',
                stacklevel=2,
            )
            return self.get(pk)  # type: ignore[attr-defined,no-any-return]
class DeleteApiMixin(BaseApiMixin):
    """Adds a ``delete`` method that removes an object by primary key."""

    def delete(self, pk: PrimaryKey, **kwargs: Any) -> None:
        """Delete the instance identified by ``pk``.

        Args:
            pk: Primary key of the object; converted with ``str()`` when
                building the detail URL.
            **kwargs: Sent as the JSON body of the DELETE request.
                NOTE(review): unlike create/put/patch, these are neither
                MISSING-filtered nor passed through ``serialize_payload``;
                confirm this is intended.

        Raises:
            Whatever ``raise_if_invalid_response`` raises when the response
            is not ``204 NO_CONTENT`` (no response body is expected).
        """
        response = self.__api__.client.delete(
            join_urls(self.url, str(pk)),
            json=kwargs,
        )
        raise_if_invalid_response(
            response,
            status_code=HTTPStatus.NO_CONTENT,
            data_type=None,
        )