# Source code for air_sdk.endpoints.simulations

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: MIT
from __future__ import annotations

import json
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from functools import cached_property, singledispatchmethod
from http import HTTPStatus
from io import TextIOBase
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterator, Literal

from air_sdk.air_model import (
    AirModel,
    BaseEndpointAPI,
    PrimaryKey,
)
from air_sdk.bc import (
    BaseCompatMixin,
    SimulationCompatMixin,
    SimulationEndpointAPICompatMixin,
)
from air_sdk.const import TopologyFormat
from air_sdk.endpoints import mixins
from air_sdk.endpoints.history import History
from air_sdk.exceptions import AirUnexpectedResponse
from air_sdk.utils import (
    join_urls,
    raise_if_invalid_response,
    validate_payload_types,
)
from air_sdk.utils import wait_for_state as wait_for_state_util

if TYPE_CHECKING:
    from air_sdk.endpoints.interfaces import InterfaceEndpointAPI
    from air_sdk.endpoints.node_instructions import NodeInstructionEndpointAPI
    from air_sdk.endpoints.nodes import NodeEndpointAPI
    from air_sdk.endpoints.services import Service, ServiceEndpointAPI
    from air_sdk.endpoints.ztp_scripts import ZTPScript


@dataclass(eq=False)
class Simulation(BaseCompatMixin, SimulationCompatMixin, AirModel):
    """Model object for a single simulation.

    Most methods are thin conveniences that forward to the matching method on
    `SimulationEndpointAPI`, passing this instance as ``simulation``. Fields
    declared with ``repr=False`` are left out of the generated ``repr``.
    """

    id: str = field(repr=False)
    name: str
    created: datetime = field(repr=False)
    modified: datetime = field(repr=False)
    state: str
    creator: str
    auto_oob_enabled: bool | None = field(repr=False)
    disable_auto_oob_dhcp: bool | None = field(repr=False)
    auto_netq_enabled: bool | None = field(repr=False)
    netq_username: str | None = field(default=None, repr=False)
    netq_password: str | None = field(default=None, repr=False)
    sleep_at: datetime | None = field(default=None, repr=False)
    expires_at: datetime | None = field(default=None, repr=False)
    documentation: str | None = field(default=None, repr=False)
    complete_checkpoint_count: int = field(default=0, repr=False)

    @classmethod
    def get_model_api(cls) -> type[SimulationEndpointAPI]:
        """Returns the respective `AirModelAPI` type for this model"""
        return SimulationEndpointAPI

    @property
    def model_api(self) -> SimulationEndpointAPI:
        """The current model API instance."""
        return self.get_model_api()(self.__api__)

    def update(self, **kwargs: Any) -> None:
        """Update this simulation via `SimulationEndpointAPI.update`.

        Args:
            **kwargs: Fields to change; forwarded unchanged to the endpoint API.
        """
        self.model_api.update(simulation=self, **kwargs)

    def enable_auto_oob(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.enable_auto_oob` for this simulation."""
        self.model_api.enable_auto_oob(simulation=self, **kwargs)

    def disable_auto_oob(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.disable_auto_oob` for this simulation."""
        self.model_api.disable_auto_oob(simulation=self, **kwargs)

    def enable_auto_netq(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.enable_auto_netq` for this simulation."""
        self.model_api.enable_auto_netq(simulation=self, **kwargs)

    def disable_auto_netq(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.disable_auto_netq` for this simulation."""
        self.model_api.disable_auto_netq(simulation=self, **kwargs)

    @validate_payload_types
    def create_ztp_script(self, *, content: str, **kwargs: Any) -> ZTPScript:
        """Create the ZTP script for this simulation and cache it on the instance.

        Args:
            content: Script content to send.
            **kwargs: Extra payload fields merged alongside ``content``.

        Returns:
            The newly created `ZTPScript`, also stored as ``self.ztp_script``.

        Raises:
            AirUnexpectedResponse: Presumably raised by
                `raise_if_invalid_response` when the response is not
                ``201 CREATED`` -- confirm against that helper.
        """
        ztp_script_endpoint_api = self.__api__.ztp_scripts
        url = ztp_script_endpoint_api.url.format(simulation_id=self.id)
        data = mixins.serialize_payload({'content': content, **kwargs})
        response = self.__api__.client.post(url, data=data)
        raise_if_invalid_response(response, status_code=HTTPStatus.CREATED)
        # Re-sync this object first, then attach the created script so the
        # cached `ztp_script` value reflects the new resource.
        self.refresh()
        self.ztp_script = ztp_script_endpoint_api.load_model(
            response.json() | {'simulation': self}
        )
        return self.ztp_script

    @validate_payload_types
    def update_ztp_script(self, *, content: str, **kwargs: Any) -> ZTPScript:
        """Patch this simulation's ZTP script and cache the updated object.

        Args:
            content: New script content.
            **kwargs: Extra fields forwarded to the ZTP script endpoint's ``patch``.

        Returns:
            The updated `ZTPScript`, also stored as ``self.ztp_script``.
        """
        ztp_script_endpoint_api = self.__api__.ztp_scripts
        updated_script = ztp_script_endpoint_api.patch(
            simulation=self, content=content, **kwargs
        )
        self.ztp_script = updated_script
        return updated_script

    def delete_ztp_script(self) -> None:
        """Delete this simulation's ZTP script and drop the cached ``ztp_script``."""
        ztp_script_endpoint_api = self.__api__.ztp_scripts
        ztp_script_endpoint_api.delete(simulation=self)
        self.clear_cached_property('ztp_script')

    def export(self, **kwargs: Any) -> dict[str, Any]:
        """Export this simulation via `SimulationEndpointAPI.export`.

        Returns:
            The decoded JSON body returned by the export endpoint.
        """
        return self.model_api.export(simulation=self, **kwargs)

    def clone(self, **kwargs: Any) -> Simulation:
        """Clone this simulation via `SimulationEndpointAPI.clone`.

        Returns:
            The `Simulation` loaded from the clone response.
        """
        return self.model_api.clone(simulation=self, **kwargs)

    def start(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.start` for this simulation."""
        self.model_api.start(simulation=self, **kwargs)

    def shutdown(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.shutdown` for this simulation."""
        self.model_api.shutdown(simulation=self, **kwargs)

    def rebuild(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.rebuild` for this simulation."""
        self.model_api.rebuild(simulation=self, **kwargs)

    def wait_for_state(
        self,
        target_states: str | list[str],
        timeout: timedelta | None = None,
        poll_interval: timedelta | None = None,
        error_states: str | list[str] | None = None,
    ) -> None:
        """Wait on this simulation's ``state`` field using the shared wait helper.

        All arguments are passed through to `air_sdk.utils.wait_for_state`
        with ``state_field='state'``.

        Args:
            target_states: State or states to wait for.
            timeout: Forwarded as-is; ``None`` leaves the helper's default.
            poll_interval: Forwarded as-is; ``None`` leaves the helper's default.
            error_states: State or states forwarded as the helper's error states.
        """
        wait_for_state_util(
            self,
            target_states,
            state_field='state',
            timeout=timeout,
            poll_interval=poll_interval,
            error_states=error_states,
        )

    def set_sleep_time(self, sleep_at: datetime | None) -> None:
        """Update the ``sleep_at`` field (``None`` is sent through as-is)."""
        self.update(sleep_at=sleep_at)

    def set_expire_time(self, expires_at: datetime | None) -> None:
        """Update the ``expires_at`` field (``None`` is sent through as-is)."""
        self.update(expires_at=expires_at)

    # RELATED MODELS
    @cached_property
    def ztp_script(self) -> ZTPScript | None:
        """The simulation's ZTP script, or ``None`` if the API answers 404.

        The value is cached; `delete_ztp_script` clears the cache and the
        create/update helpers overwrite it.
        """
        try:
            return self.__api__.ztp_scripts.get(simulation=self)
        except AirUnexpectedResponse as e:
            # Only return None for 404 (ZTP script not found)
            if e.status_code == HTTPStatus.NOT_FOUND:
                return None
            # Re-raise other errors (500, 403, None, etc.) as they indicate
            # real problems
            raise

    def get_history(self, **kwargs: Any) -> Iterator[History]:
        """List history entries recorded for this simulation.

        Args:
            **kwargs: Additional filters forwarded to ``histories.list``.
        """
        return self.__api__.histories.list(
            model='simulation', object_id=self.id, **kwargs
        )

    @property
    def nodes(self) -> NodeEndpointAPI:
        """A `NodeEndpointAPI` pre-filtered to this simulation."""
        # Imported here rather than at module level (the module-level import
        # is TYPE_CHECKING-only) -- presumably to avoid a circular import.
        from air_sdk.endpoints.nodes import NodeEndpointAPI

        return NodeEndpointAPI(
            self.__api__, default_filters={'simulation': str(self.__pk__)}
        )

    @property
    def interfaces(self) -> InterfaceEndpointAPI:
        """An `InterfaceEndpointAPI` pre-filtered to this simulation."""
        from air_sdk.endpoints.interfaces import InterfaceEndpointAPI

        return InterfaceEndpointAPI(
            self.__api__, default_filters={'simulation': str(self.__pk__)}
        )

    @property
    def node_instructions(self) -> NodeInstructionEndpointAPI:
        """A `NodeInstructionEndpointAPI` pre-filtered to this simulation."""
        from air_sdk.endpoints.node_instructions import NodeInstructionEndpointAPI

        return NodeInstructionEndpointAPI(
            self.__api__, default_filters={'simulation': str(self.__pk__)}
        )

    @property
    def services(self) -> ServiceEndpointAPI:
        """A `ServiceEndpointAPI` pre-filtered to this simulation."""
        from air_sdk.endpoints.services import ServiceEndpointAPI

        return ServiceEndpointAPI(
            self.__api__, default_filters={'simulation': str(self.__pk__)}
        )

    def create_service(self, **kwargs: Any) -> Service:
        """Create a service in this simulation via `SimulationEndpointAPI.create_service`.

        Args:
            **kwargs: Either ``interface`` or both ``node_name`` and
                ``interface_name``, plus any other service fields.
        """
        return self.model_api.create_service(simulation=self, **kwargs)

    def node_bulk_assign(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.node_bulk_assign` with this simulation."""
        # No `return`: the endpoint method returns None, and the sibling bulk
        # helpers below follow the same statement form.
        self.model_api.node_bulk_assign(simulation=self, **kwargs)

    def node_bulk_reset(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.node_bulk_reset` with this simulation."""
        self.model_api.node_bulk_reset(simulation=self, **kwargs)

    def node_bulk_rebuild(self, **kwargs: Any) -> None:
        """Forward to `SimulationEndpointAPI.node_bulk_rebuild` with this simulation."""
        self.model_api.node_bulk_rebuild(simulation=self, **kwargs)


class SimulationEndpointAPI(
    SimulationEndpointAPICompatMixin,
    mixins.ListApiMixin[Simulation],
    mixins.CreateApiMixin[Simulation],
    mixins.GetApiMixin[Simulation],
    mixins.PatchApiMixin[Simulation],
    mixins.DeleteApiMixin,
    BaseEndpointAPI[Simulation],
):
    """Endpoint API for simulations: CRUD mixins plus import/export and actions."""

    API_PATH = 'simulations'
    IMPORT_PATH = 'import'
    EXPORT_PATH = 'export'
    START_PATH = 'start'
    SHUTDOWN_PATH = 'shutdown'
    PARSE_PATH = 'parse'
    REBUILD_PATH = 'rebuild'
    NODE_BULK_ASSIGN_PATH = 'nodes/bulk-assign'
    NODE_BULK_RESET_PATH = 'nodes/bulk-reset'
    NODE_BULK_REBUILD_PATH = 'nodes/bulk-rebuild'
    model = Simulation

    # ------------------------------------------------------------------
    # Input resolution helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _require_dict(resolved: Any) -> dict[str, Any]:
        """Return ``resolved`` unchanged if it is a dict, else raise ValueError."""
        if not isinstance(resolved, dict):
            raise ValueError(f'JSON data must be a dict, got {type(resolved)}')
        return resolved

    @staticmethod
    def _existing_file(candidate: str) -> Path | None:
        """Return ``candidate`` as a `Path` if it names an existing regular file.

        The string may actually be inline JSON/DOT content. Such content is
        often longer than the filesystem's maximum name length, in which case
        the underlying ``stat`` fails with ``ENAMETOOLONG`` instead of
        reporting "not found". That case is treated as "not a file"; any other
        ``OSError`` (for example a permission problem) still propagates.
        """
        import errno

        path = Path(candidate)
        try:
            # is_file() is already False for non-existent paths, so a separate
            # exists() check is not needed.
            return path if path.is_file() else None
        except OSError as exc:
            if exc.errno == errno.ENAMETOOLONG:
                return None
            raise

    @singledispatchmethod
    def _resolve_json_from_source(
        self, source: dict[str, Any] | str | Path | TextIOBase
    ) -> dict[str, Any]:
        """Resolve JSON data from various sources.

        Handles source as dict, JSON string, file path, or file object.

        Args:
            source: JSON data as dict, JSON string, Path to JSON file, or file
                object

        Returns:
            Resolved dict

        Raises:
            ValueError: If resolved content is not a dict
            JSONDecodeError: If string/file content is not valid JSON
            FileNotFoundError: If file path does not exist
        """
        # Default implementation: handle dict
        if not isinstance(source, dict):
            raise ValueError(f'JSON data must be a dict, got {type(source)}')
        return source

    @_resolve_json_from_source.register
    def _(self, source: TextIOBase) -> dict[str, Any]:
        """Resolve JSON from file object."""
        return self._require_dict(json.load(source))

    @_resolve_json_from_source.register
    def _(self, source: Path) -> dict[str, Any]:
        """Resolve JSON from Path object."""
        with source.open('r') as f:
            resolved = json.load(f)
        return self._require_dict(resolved)

    @_resolve_json_from_source.register
    def _(self, source: str) -> dict[str, Any]:
        """Resolve JSON from string (file path or JSON string)."""
        # Try as file path first
        path = self._existing_file(source)
        if path is not None:
            with path.open('r') as f:
                resolved = json.load(f)
        else:
            # Parse as JSON string
            # TODO: Consider raising FileNotFoundError for path-like strings
            # that don't exist (e.g., contain '/', '\', or end with '.json')
            # to provide clearer error messages instead of JSONDecodeError
            resolved = json.loads(source)
        return self._require_dict(resolved)

    def _resolve_simulation_manifest(
        self, simulation_manifest: dict[str, Any] | str | Path | TextIOBase
    ) -> dict[str, Any]:
        """Resolve a manifest and, for JSON format, its nested ``content`` field.

        Args:
            simulation_manifest: Manifest as dict, JSON string, path, or file
                object.

        Returns:
            The manifest dict. When ``format`` is JSON (case-insensitive) and
            ``content`` is not already a dict, ``content`` is itself resolved
            through `_resolve_json_from_source`. In that case a shallow copy is
            returned so a dict passed in by the caller is never modified.
        """
        # First, resolve the top-level manifest
        resolved = self._resolve_json_from_source(simulation_manifest)
        if 'content' not in resolved or 'format' not in resolved:
            return resolved

        # Second, resolve the content field if it needs resolution (JSON format)
        format_type = resolved['format']
        content = resolved['content']
        is_json_format = (
            isinstance(format_type, str)
            and format_type.upper() == TopologyFormat.JSON
        )
        # Only resolve if content is not already a dict
        if is_json_format and not isinstance(content, dict):
            resolved = {
                **resolved,
                'content': self._resolve_json_from_source(content),
            }
        return resolved

    def _resolve_dot_topology(self, topology: str | Path | TextIOBase) -> str:
        """Resolve DOT topology text from a file object, path, or string.

        Args:
            topology: Open text file, `Path` to a file, or a string that is
                either a path to an existing file or the raw DOT content.

        Returns:
            The DOT content as a string.

        Raises:
            ValueError: If ``topology`` is of an unsupported type or the
                resolved content is not a ``str``.
        """
        if isinstance(topology, TextIOBase):
            resolved = topology.read()
        elif isinstance(topology, Path):
            with topology.open('r') as f:
                resolved = f.read()
        elif isinstance(topology, str):
            # Try as file path first
            path = self._existing_file(topology)
            if path is not None:
                with path.open('r') as f:
                    resolved = f.read()
            else:
                # Use as raw DOT content
                resolved = topology
        else:
            # Without this branch `resolved` would be unbound below.
            raise ValueError(
                f'DOT topology format requires str content, got {type(topology)}'
            )
        if not isinstance(resolved, str):
            raise ValueError(
                f'DOT topology format requires str content, got {type(resolved)}'
            )
        return resolved

    def _wait_and_start_simulation(
        self,
        simulation: Simulation,
        timeout: timedelta = timedelta(seconds=120),
        poll_interval: timedelta = timedelta(seconds=2),
    ) -> None:
        """Wait for simulation to be ready and then start it.

        Args:
            simulation: The simulation to wait for and start
            timeout: Maximum time to wait (default: 120 seconds)
            poll_interval: Time between status checks (default: 2 seconds)

        Raises:
            ValueError: If simulation enters an error state
            TimeoutError: If timeout is exceeded before simulation is ready
        """
        # TODO: Use constants when Sim state MR is merged
        # Wait for simulation to be INACTIVE (ready to start)
        simulation.wait_for_state(
            target_states='INACTIVE',
            timeout=timeout,
            poll_interval=poll_interval,
            error_states='INVALID',
        )
        # Start the simulation
        simulation.start()

    # ------------------------------------------------------------------
    # Import
    # ------------------------------------------------------------------
    @validate_payload_types
    def import_from_data(
        self,
        *,
        attempt_start: bool = False,
        start_timeout: timedelta | None = None,
        **kwargs: Any,
    ) -> Simulation:
        """POST an import payload and optionally start the resulting simulation.

        Args:
            attempt_start: When true, wait for the new simulation to become
                ready and then start it.
            start_timeout: Optional override for the wait timeout; ``None``
                keeps the default of `_wait_and_start_simulation`.
            **kwargs: Import payload (e.g. ``format``, ``content``, ``name``),
                serialized and sent as-is.

        Returns:
            The `Simulation` loaded from the import response.
        """
        # Let API validate format/content and all parameters
        response = self.__api__.client.post(
            join_urls(self.url, self.IMPORT_PATH),
            data=mixins.serialize_payload(kwargs),
        )
        raise_if_invalid_response(response, status_code=HTTPStatus.CREATED)
        sim: Simulation = self.load_model(response.json())

        # If attempt_start, wait for simulation to be ready and start it
        if attempt_start:
            if start_timeout is not None:
                self._wait_and_start_simulation(sim, timeout=start_timeout)
            else:
                self._wait_and_start_simulation(sim)
        return sim

    @validate_payload_types
    def import_from_simulation_manifest(
        self,
        *,
        simulation_manifest: dict[str, Any] | str | Path | TextIOBase,
        attempt_start: bool = False,
        start_timeout: timedelta | None = None,
    ) -> Simulation:
        """Import a simulation from a manifest (dict, JSON string, path, or file).

        Args:
            simulation_manifest: Manifest source; resolved with
                `_resolve_simulation_manifest`.
            attempt_start: See `import_from_data`.
            start_timeout: See `import_from_data`.

        Returns:
            The imported `Simulation`.
        """
        # Resolve manifest (including content field)
        resolved_manifest = self._resolve_simulation_manifest(simulation_manifest)
        # Pass attempt_start and start_timeout as separate parameters
        # (they should not be in the manifest itself)
        return self.import_from_data(
            **resolved_manifest,
            attempt_start=attempt_start,
            start_timeout=start_timeout,
        )

    @validate_payload_types
    def import_from_dot(
        self,
        *,
        topology_data: str | Path | TextIOBase,
        attempt_start: bool = False,
        start_timeout: timedelta | None = None,
        **kwargs: Any,
    ) -> Simulation:
        """Import a simulation from DOT topology data.

        Args:
            topology_data: DOT content, a path to a DOT file, or an open text
                file.
            attempt_start: See `import_from_data`.
            start_timeout: See `import_from_data`.
            **kwargs: Additional import fields; must include ``name``.

        Returns:
            The imported `Simulation`.

        Raises:
            ValueError: If ``name`` is missing or the DOT data cannot be
                resolved to a string.
        """
        # Resolve DOT topology data
        resolved_content = self._resolve_dot_topology(topology_data)

        # Require name to be provided
        # Note: BC layer already maps 'title' to 'name', so we only check for
        # 'name' here
        # TODO: Remove this once the API allows create without name and let the
        # DOT2JSON parser to extract the name from the DOT content
        if 'name' not in kwargs:
            raise ValueError(
                'The "name" (or "title") parameter is required when importing '
                'DOT topology. Please provide a name for the simulation.'
            )

        # Call import_from_data with DOT format
        return self.import_from_data(
            format=TopologyFormat.DOT,
            content=resolved_content,
            attempt_start=attempt_start,
            start_timeout=start_timeout,
            **kwargs,
        )

    # ------------------------------------------------------------------
    # Update / export / clone
    # ------------------------------------------------------------------
    @validate_payload_types
    def update(self, *, simulation: Simulation | PrimaryKey, **kwargs: Any) -> Simulation:
        """Patch a simulation and return the updated model.

        Args:
            simulation: `Simulation` instance or its primary key.
            **kwargs: Fields to patch.

        Returns:
            The `Simulation` returned by ``patch``. If an instance was passed
            in, it is also refreshed in place from that result.
        """
        sim_id = simulation.id if isinstance(simulation, Simulation) else simulation
        result = self.patch(sim_id, **kwargs)
        if isinstance(simulation, Simulation):
            # Refresh the original object using the patch response data
            simulation.__refresh__(refreshed_obj=result)
        return result

    @validate_payload_types
    def export(
        self,
        *,
        simulation: Simulation | PrimaryKey,
        topology_format: Literal['JSON'] = 'JSON',
        **kwargs: Any,
    ) -> dict[str, Any]:
        """GET the export of a simulation.

        Args:
            simulation: `Simulation` instance or its primary key.
            topology_format: Export format; only ``'JSON'`` is accepted by the
                annotation.
            **kwargs: Extra query parameters.

        Returns:
            The decoded JSON response body.
        """
        sim_id = simulation.id if isinstance(simulation, Simulation) else simulation
        url = join_urls(self.url, str(sim_id), self.EXPORT_PATH)
        # Round-trip through serialize_payload/json.loads -- presumably so the
        # query parameters get the same serialization as request bodies.
        params = json.loads(
            mixins.serialize_payload({'topology_format': topology_format, **kwargs})
        )
        response = self.__api__.client.get(url, params=params)
        raise_if_invalid_response(response)
        response_data: dict[str, Any] = response.json()
        return response_data

    @validate_payload_types
    def clone(self, *, simulation: Simulation | PrimaryKey, **kwargs: Any) -> Simulation:
        """POST a clone request and return the resulting simulation.

        Args:
            simulation: `Simulation` instance or primary key, sent in the
                payload under ``simulation``.
            **kwargs: Extra payload fields.
        """
        url = join_urls(self.url, 'clone')
        response = self.__api__.client.post(
            url, data=mixins.serialize_payload({'simulation': simulation, **kwargs})
        )
        raise_if_invalid_response(response, status_code=HTTPStatus.CREATED)
        return self.load_model(response.json())

    # ------------------------------------------------------------------
    # Auto OOB / NetQ toggles
    # ------------------------------------------------------------------
    def _patch_toggle(
        self,
        simulation: Simulation | PrimaryKey,
        action: str,
        payload: dict[str, Any],
    ) -> None:
        """PATCH ``<simulation>/<action>``, validate, then refresh the instance."""
        url = mixins.build_resource_url(self.url, simulation, action)
        response = self.__api__.client.patch(
            url, data=mixins.serialize_payload(payload)
        )
        raise_if_invalid_response(response, data_type=None)
        if isinstance(simulation, Simulation):
            simulation.refresh()

    @validate_payload_types
    def enable_auto_oob(
        self, *, simulation: Simulation | PrimaryKey, **kwargs: Any
    ) -> None:
        """PATCH the ``enable-auto-oob`` action; refreshes a passed instance."""
        self._patch_toggle(simulation, 'enable-auto-oob', kwargs)

    @validate_payload_types
    def disable_auto_oob(
        self, *, simulation: Simulation | PrimaryKey, **kwargs: Any
    ) -> None:
        """PATCH the ``disable-auto-oob`` action; refreshes a passed instance."""
        self._patch_toggle(simulation, 'disable-auto-oob', kwargs)

    @validate_payload_types
    def enable_auto_netq(
        self, *, simulation: Simulation | PrimaryKey, **kwargs: Any
    ) -> None:
        """PATCH the ``enable-auto-netq`` action; refreshes a passed instance."""
        self._patch_toggle(simulation, 'enable-auto-netq', kwargs)

    @validate_payload_types
    def disable_auto_netq(
        self, *, simulation: Simulation | PrimaryKey, **kwargs: Any
    ) -> None:
        """PATCH the ``disable-auto-netq`` action; refreshes a passed instance."""
        self._patch_toggle(simulation, 'disable-auto-netq', kwargs)

    # ------------------------------------------------------------------
    # Lifecycle actions
    # ------------------------------------------------------------------
    def _patch_lifecycle(
        self,
        simulation: Simulation | PrimaryKey,
        action: str,
        payload: dict[str, Any],
    ) -> None:
        """PATCH ``<simulation>/<action>``, refresh the instance, then validate.

        The refresh runs before the response check, so a passed instance is
        re-synced even when the request is then reported as invalid.
        """
        url = mixins.build_resource_url(self.url, simulation, action)
        response = self.__api__.client.patch(
            url, data=mixins.serialize_payload(payload)
        )
        if isinstance(simulation, Simulation):
            simulation.refresh()
        raise_if_invalid_response(response)

    @validate_payload_types
    def start(self, *, simulation: Simulation | PrimaryKey, **kwargs: Any) -> None:
        """PATCH the start action for a simulation."""
        self._patch_lifecycle(simulation, self.START_PATH, kwargs)

    @validate_payload_types
    def rebuild(self, *, simulation: Simulation | PrimaryKey, **kwargs: Any) -> None:
        """PATCH the rebuild action for a simulation."""
        self._patch_lifecycle(simulation, self.REBUILD_PATH, kwargs)

    @validate_payload_types
    def shutdown(self, *, simulation: Simulation | PrimaryKey, **kwargs: Any) -> None:
        """PATCH the shutdown action for a simulation."""
        self._patch_lifecycle(simulation, self.SHUTDOWN_PATH, kwargs)

    # ------------------------------------------------------------------
    # Services
    # ------------------------------------------------------------------
    @validate_payload_types
    def create_service(
        self,
        *,
        simulation: Simulation | PrimaryKey,
        **kwargs: Any,
    ) -> Service:
        """Create a service on an interface of a node in the simulation.

        Args:
            simulation: `Simulation` instance or its primary key (fetched with
                ``get`` when a key is given).
            **kwargs: Either ``interface`` (backward-compatible form), or both
                ``node_name`` and ``interface_name``; remaining fields are
                forwarded to the services API.

        Returns:
            The created `Service`.

        Raises:
            ValueError: If the required parameters are missing, or the node or
                interface cannot be found.
        """
        # Get simulation object if needed
        sim = simulation if isinstance(simulation, Simulation) else self.get(simulation)

        # BC: If 'interface' param provided, delegate to services API
        # (which handles 'node:interface' parsing via
        # ServiceEndpointAPICompatMixin)
        if 'interface' in kwargs:
            # Pass simulation for BC 'node:interface' string resolution
            return sim.services.create(simulation=sim, **kwargs)  # type: ignore[call-arg]

        # V3: Extract required parameters
        node_name = kwargs.pop('node_name', None)
        interface_name = kwargs.pop('interface_name', None)
        if not node_name or not interface_name:
            raise ValueError(
                "Must provide either 'interface' parameter or both "
                "'node_name' and 'interface_name' parameters"
            )

        # Resolve node name to Node object
        node_obj = next(sim.nodes.list(name=node_name), None)
        if not node_obj:
            raise ValueError(f'Node "{node_name}" not found in simulation')

        # Resolve interface name to interface ID
        interface_obj = next(
            self.__api__.interfaces.list(node=node_obj.id, name=interface_name),
            None,
        )
        if not interface_obj:
            raise ValueError(f'Interface "{interface_name}" not found on node')

        interface_id = str(interface_obj.id)
        return sim.services.create(interface=interface_id, **kwargs)

    # ------------------------------------------------------------------
    # Parse and node bulk actions
    # ------------------------------------------------------------------
    def parse(self, **kwargs: Any) -> dict[str, Any]:
        """POST ``kwargs`` to the parse endpoint and return the decoded JSON body."""
        url = join_urls(self.url, self.PARSE_PATH)
        response = self.__api__.client.post(url, data=mixins.serialize_payload(kwargs))
        raise_if_invalid_response(response)
        response_data: dict[str, Any] = response.json()
        return response_data

    def _patch_node_bulk(self, path: str, payload: dict[str, Any]) -> None:
        """PATCH a collection-level bulk path and require ``204 NO CONTENT``."""
        url = join_urls(self.url, path)
        response = self.__api__.client.patch(
            url, data=mixins.serialize_payload(payload)
        )
        raise_if_invalid_response(
            response, status_code=HTTPStatus.NO_CONTENT, data_type=None
        )

    def node_bulk_assign(self, **kwargs: Any) -> None:
        """PATCH ``nodes/bulk-assign`` with ``kwargs`` as the payload."""
        self._patch_node_bulk(self.NODE_BULK_ASSIGN_PATH, kwargs)

    def node_bulk_reset(self, **kwargs: Any) -> None:
        """PATCH ``nodes/bulk-reset`` with ``kwargs`` as the payload."""
        self._patch_node_bulk(self.NODE_BULK_RESET_PATH, kwargs)

    def node_bulk_rebuild(self, **kwargs: Any) -> None:
        """PATCH ``nodes/bulk-rebuild`` with ``kwargs`` as the payload."""
        self._patch_node_bulk(self.NODE_BULK_REBUILD_PATH, kwargs)