import logging
from typing import TYPE_CHECKING, Any
import numpy as np
from .mtor_types import FloatArray, OptimalControlSolution, PhaseID, ProblemProtocol
if TYPE_CHECKING:
    # Placeholder for imports needed only by static type checkers; nothing is
    # imported here at runtime.
    pass

# Module-level logger, used to report best-effort summary failures.
logger = logging.getLogger(__name__)
class Solution:
    """
    Optimal control solution with comprehensive data access and analysis capabilities.

    Provides a unified interface for accessing optimization results, trajectories,
    solver diagnostics, mesh information, and adaptive refinement data. Supports
    both single-phase and multiphase problems with automatic data concatenation.

    **Data Access Patterns:**

    **Mission-wide access (concatenates all phases in ascending phase-ID order;
    phases that do not define the variable are skipped):**

    - `solution["variable_name"]` - Variable across all phases
    - `solution["time_states"]` - State time points across all phases
    - `solution["time_controls"]` - Control time points across all phases

    **Phase-specific access:**

    - `solution[(phase_id, "variable_name")]` - Variable in specific phase
    - `solution[(phase_id, "time_states")]` - State times in specific phase
    - `solution[(phase_id, "time_controls")]` - Control times in specific phase

    **Existence checking:**

    - `"variable_name" in solution` - Check mission-wide variable
    - `(phase_id, "variable") in solution` - Check phase-specific variable

    Indexing a solution whose solve did not succeed raises ``RuntimeError``;
    membership tests on such a solution simply return ``False``.

    Examples:
        Basic solution workflow:

        >>> solution = mtor.solve_adaptive(problem)
        >>> if solution.status["success"]:
        ...     print(f"Objective: {solution.status['objective']:.6f}")
        ...     solution.plot()

        Mission-wide data access:

        >>> altitude_all = solution["altitude"]  # All phases concatenated
        >>> velocity_all = solution["velocity"]  # All phases concatenated
        >>> state_times_all = solution["time_states"]  # All phase state times

        Phase-specific data access:

        >>> altitude_p1 = solution[(1, "altitude")]  # Phase 1 only
        >>> velocity_p2 = solution[(2, "velocity")]  # Phase 2 only
        >>> state_times_p1 = solution[(1, "time_states")]

        Data extraction patterns:

        >>> # Final/initial values
        >>> final_altitude = solution["altitude"][-1]
        >>> initial_velocity = solution["velocity"][0]
        >>> final_mass_p1 = solution[(1, "mass")][-1]
        >>>
        >>> # Extrema
        >>> max_altitude = max(solution["altitude"])
        >>> min_thrust_p2 = min(solution[(2, "thrust")])

        Variable existence checking:

        >>> if "altitude" in solution:
        ...     altitude_data = solution["altitude"]
        >>> if (2, "thrust") in solution:
        ...     thrust_p2 = solution[(2, "thrust")]

        Phase information access:

        >>> for phase_id, phase_data in solution.phases.items():
        ...     duration = phase_data["times"]["duration"]
        ...     state_names = phase_data["variables"]["state_names"]

        Solution validation:

        >>> status = solution.status
        >>> if status["success"]:
        ...     objective = status["objective"]
        ...     mission_time = status["total_mission_time"]
        ... else:
        ...     print(f"Failed: {status['message']}")
    """

    def __init__(
        self,
        raw_solution: OptimalControlSolution | None,
        problem: ProblemProtocol | None,
        auto_summary: bool = True,
    ) -> None:
        """
        Initialize solution wrapper from raw multiphase optimization results.

        Args:
            raw_solution: Raw optimization results from the solver, or None when
                no solution is available.
            problem: Problem protocol instance used to look up the ordered
                state/control names of each phase, or None.
            auto_summary: Whether to automatically display the comprehensive
                summary (default: True).
        """
        # Store raw data for internal use.
        self._raw_solution = raw_solution
        self._problem = problem

        # Expose the raw CasADi objects for advanced users.
        self.raw_solution = raw_solution.raw_solution if raw_solution is not None else None
        self.opti = raw_solution.opti_object if raw_solution is not None else None

        # Per-phase variable names, in the order used to index the raw
        # state/control arrays, for dictionary-style access.
        self._phase_state_names = {}
        self._phase_control_names = {}
        if problem is not None:
            for phase_id in problem._get_phase_ids():
                self._phase_state_names[phase_id] = problem._get_phase_ordered_state_names(
                    phase_id
                )
                self._phase_control_names[phase_id] = problem._get_phase_ordered_control_names(
                    phase_id
                )

        if auto_summary:
            self._show_comprehensive_summary()

    def _show_comprehensive_summary(self) -> None:
        """Print the comprehensive summary; failures are logged, never raised."""
        try:
            from .summary import print_comprehensive_solution_summary

            print_comprehensive_solution_summary(self)
        except ImportError as e:
            logger.warning("Could not import comprehensive summary: %s", e)
        except Exception as e:
            # Deliberately best-effort: a broken summary must not break solving.
            logger.warning("Error in comprehensive summary: %s", e)

    @property
    def status(self) -> dict[str, Any]:
        """
        Complete solution status and optimization results.

        Provides optimization outcome information including success status,
        objective value, and mission timing.

        Returns:
            Dictionary containing status information:

            - **success** (bool): Optimization success status
            - **message** (str): Solver status message
            - **objective** (float): Final objective value (NaN if unavailable)
            - **total_mission_time** (float): Time span from the earliest phase
              initial time to the latest phase terminal time (NaN if phase
              timing is unavailable)

        Examples:
            Success checking:

            >>> if solution.status["success"]:
            ...     print("Optimization successful")

            Objective extraction:

            >>> objective = solution.status["objective"]
            >>> mission_time = solution.status["total_mission_time"]

            Error handling:

            >>> status = solution.status
            >>> if not status["success"]:
            ...     print(f"Failed: {status['message']}")
            ...     print(f"Objective: {status['objective']}")  # May be NaN

            Status inspection:

            >>> print(f"Success: {solution.status['success']}")
            >>> print(f"Message: {solution.status['message']}")
            >>> print(f"Objective: {solution.status['objective']:.6e}")
            >>> print(f"Mission time: {solution.status['total_mission_time']:.3f}")
        """
        raw = self._raw_solution
        if raw is None:
            return {
                "success": False,
                "message": "No solution available",
                "objective": float("nan"),
                "total_mission_time": float("nan"),
            }

        # Mission time is the overall span, not the sum of phase durations.
        initial_times = raw.phase_initial_times
        terminal_times = raw.phase_terminal_times
        if initial_times and terminal_times:
            total_time = max(terminal_times.values()) - min(initial_times.values())
        else:
            total_time = float("nan")

        return {
            "success": raw.success,
            "message": raw.message,
            "objective": raw.objective if raw.objective is not None else float("nan"),
            "total_mission_time": total_time,
        }

    @property
    def phases(self) -> dict[PhaseID, dict[str, Any]]:
        """
        Per-phase information and data organization.

        Every call builds fresh copies of the list/array data, so mutating the
        returned structures never alters the solution itself.

        Returns:
            Dictionary mapping phase IDs to phase data (empty when no solution):

            - **times** (dict): Phase timing

              - initial (float): Phase start time
              - final (float): Phase end time
              - duration (float): Phase duration (NaN if either bound is NaN)

            - **variables** (dict): Variable information

              - state_names (list): State variable names
              - control_names (list): Control variable names
              - num_states (int): Number of states
              - num_controls (int): Number of controls

            - **mesh** (dict): Mesh configuration

              - polynomial_degrees (list): Polynomial degree per interval
              - mesh_nodes (FloatArray): Mesh node locations
              - num_intervals (int): Total intervals

            - **time_arrays** (dict): Time coordinates

              - states (FloatArray): State time points
              - controls (FloatArray): Control time points

            - **integrals** (float | FloatArray | None): Integral values

        Examples:
            Phase iteration:

            >>> for phase_id, phase_data in solution.phases.items():
            ...     print(f"Phase {phase_id}")

            Timing information:

            >>> phase_1 = solution.phases[1]
            >>> duration = phase_1["times"]["duration"]
            >>> start_time = phase_1["times"]["initial"]
            >>> end_time = phase_1["times"]["final"]

            Variable information:

            >>> variables = solution.phases[1]["variables"]
            >>> state_names = variables["state_names"]  # ["x", "y", "vx", "vy"]
            >>> control_names = variables["control_names"]  # ["thrust_x", "thrust_y"]
            >>> num_states = variables["num_states"]  # 4
            >>> num_controls = variables["num_controls"]  # 2

            Mesh information:

            >>> mesh = solution.phases[1]["mesh"]
            >>> degrees = mesh["polynomial_degrees"]  # [6, 8, 6]
            >>> intervals = mesh["num_intervals"]  # 3
            >>> nodes = mesh["mesh_nodes"]  # [-1, -0.5, 0.5, 1]

            Time arrays:

            >>> time_arrays = solution.phases[1]["time_arrays"]
            >>> state_times = time_arrays["states"]  # State time coordinates
            >>> control_times = time_arrays["controls"]  # Control time coordinates

            Integral values:

            >>> integrals = solution.phases[1]["integrals"]
            >>> if isinstance(integrals, float):
            ...     single_integral = integrals  # Single integral
            ... else:
            ...     multiple_integrals = integrals  # Array of integrals (or None)
        """
        raw = self._raw_solution
        if raw is None:
            return {}

        empty = np.array([], dtype=np.float64)
        phases_data = {}
        for phase_id in self._get_phase_ids():
            # Time information
            initial_time = raw.phase_initial_times.get(phase_id, float("nan"))
            final_time = raw.phase_terminal_times.get(phase_id, float("nan"))
            duration = (
                final_time - initial_time
                if not (np.isnan(initial_time) or np.isnan(final_time))
                else float("nan")
            )

            # Variable information
            state_names = self._phase_state_names.get(phase_id, [])
            control_names = self._phase_control_names.get(phase_id, [])

            # Mesh information
            polynomial_degrees = raw.phase_mesh_intervals.get(phase_id, [])
            mesh_nodes = raw.phase_mesh_nodes.get(phase_id, empty)

            # Time arrays
            time_states = raw.phase_time_states.get(phase_id, empty)
            time_controls = raw.phase_time_controls.get(phase_id, empty)

            # Integrals: copy arrays so callers cannot mutate solver results,
            # consistent with every other array returned here.
            integrals = raw.phase_integrals.get(phase_id, None)
            if isinstance(integrals, np.ndarray):
                integrals = integrals.copy()

            phases_data[phase_id] = {
                "times": {"initial": initial_time, "final": final_time, "duration": duration},
                "variables": {
                    "state_names": state_names.copy(),
                    "control_names": control_names.copy(),
                    "num_states": len(state_names),
                    "num_controls": len(control_names),
                },
                "mesh": {
                    "polynomial_degrees": polynomial_degrees.copy() if polynomial_degrees else [],
                    "mesh_nodes": mesh_nodes.copy()
                    if mesh_nodes.size > 0
                    else np.array([], dtype=np.float64),
                    "num_intervals": len(polynomial_degrees) if polynomial_degrees else 0,
                },
                "time_arrays": {"states": time_states.copy(), "controls": time_controls.copy()},
                "integrals": integrals,
            }
        return phases_data

    @property
    def parameters(self) -> dict[str, Any] | None:
        """
        Static parameter optimization results.

        Returns None if there is no solution or no static parameters.

        Returns:
            Parameter information dictionary or None:

            - **values** (FloatArray): Optimized parameter values (a copy)
            - **names** (list[str] | None): Parameter names if the problem
              exposes them, otherwise None
            - **count** (int): Number of static parameters

        Examples:
            Parameter existence check:

            >>> if solution.parameters is not None:
            ...     print("Problem has static parameters")

            Parameter access:

            >>> params = solution.parameters
            >>> if params:
            ...     values = params["values"]  # [500.0, 1500.0, 0.1]
            ...     count = params["count"]  # 3
            ...     names = params["names"]  # ["mass", "thrust", "drag"] or None

            Named parameter access:

            >>> params = solution.parameters
            >>> if params and params["names"]:
            ...     for name, value in zip(params["names"], params["values"]):
            ...         print(f"{name}: {value:.6f}")

            Unnamed parameter access:

            >>> params = solution.parameters
            >>> if params:
            ...     for i, value in enumerate(params["values"]):
            ...         print(f"Parameter {i}: {value:.6f}")

            No parameters case:

            >>> if solution.parameters is None:
            ...     print("No static parameters in problem")
        """
        raw = self._raw_solution
        if raw is None or raw.static_parameters is None:
            return None

        # Parameter names are optional metadata; missing or malformed
        # definitions simply leave them as None.
        param_names = None
        if self._problem is not None and hasattr(self._problem, "_static_parameters"):
            try:
                static_params = self._problem._static_parameters
                if hasattr(static_params, "parameter_names"):
                    param_names = static_params.parameter_names.copy()
            except (AttributeError, IndexError):
                pass

        return {
            "values": raw.static_parameters.copy(),
            "names": param_names,
            "count": len(raw.static_parameters),
        }

    @property
    def adaptive(self) -> dict[str, Any] | None:
        """
        Adaptive mesh refinement results and convergence diagnostics.

        Only available when the raw solution carries adaptive data; returns
        None otherwise (e.g. fixed-mesh solutions).

        Returns:
            Adaptive algorithm data dictionary or None:

            - **converged** (bool): Algorithm convergence status
            - **iterations** (int): Refinement iterations performed
            - **target_tolerance** (float): Target error tolerance
            - **phase_converged** (dict): Per-phase convergence status
            - **final_errors** (dict): Final error estimates per phase
            - **gamma_factors** (dict): Normalization factors per phase (entries may be None)

        Examples:
            Adaptive solution check:

            >>> if solution.adaptive:
            ...     print("Adaptive solution available")

            Convergence assessment:

            >>> adaptive_info = solution.adaptive
            >>> if adaptive_info:
            ...     converged = adaptive_info["converged"]
            ...     iterations = adaptive_info["iterations"]
            ...     tolerance = adaptive_info["target_tolerance"]

            Per-phase convergence:

            >>> if solution.adaptive:
            ...     for phase_id, converged in solution.adaptive["phase_converged"].items():
            ...         status = "✓" if converged else "✗"
            ...         print(f"Phase {phase_id}: {status}")

            Error analysis:

            >>> if solution.adaptive:
            ...     for phase_id, errors in solution.adaptive["final_errors"].items():
            ...         max_error = max(errors) if len(errors) else 0.0
            ...         print(f"Phase {phase_id} max error: {max_error:.2e}")

            Algorithm statistics:

            >>> adaptive = solution.adaptive
            >>> if adaptive:
            ...     print(f"Converged: {adaptive['converged']}")
            ...     print(f"Iterations: {adaptive['iterations']}")
            ...     print(f"Target tolerance: {adaptive['target_tolerance']:.1e}")

            Fixed mesh solution:

            >>> if solution.adaptive is None:
            ...     print("Fixed mesh solution - no adaptive data")
        """
        raw = self._raw_solution
        if raw is None or raw.adaptive_data is None:
            return None

        adaptive_data = raw.adaptive_data
        return {
            "converged": adaptive_data.converged,
            "iterations": adaptive_data.total_iterations,
            "target_tolerance": adaptive_data.target_tolerance,
            "phase_converged": adaptive_data.phase_converged.copy(),
            "final_errors": {
                phase_id: errors.copy()
                for phase_id, errors in adaptive_data.final_phase_error_estimates.items()
            },
            "gamma_factors": {
                phase_id: factors.copy() if factors is not None else None
                for phase_id, factors in adaptive_data.phase_gamma_factors.items()
            },
        }

    def _get_phase_ids(self) -> list[PhaseID]:
        """Return the solution's phase IDs in ascending order (empty if no solution)."""
        if self._raw_solution is None:
            return []
        return sorted(self._raw_solution.phase_initial_times.keys())

    def __getitem__(self, key: str | tuple[PhaseID, str]) -> FloatArray:
        """
        Return trajectory data by variable name or ``(phase_id, name)`` tuple.

        Raises:
            RuntimeError: If the solve did not succeed (or there is no solution).
            KeyError: If the key type is invalid or the variable is not found.
        """
        status = self.status
        if not status["success"]:
            raise RuntimeError(
                f"Cannot access variable '{key}': Solution failed with message: {status['message']}"
            )
        if isinstance(key, tuple):
            return self._get_by_tuple_key(key)
        if isinstance(key, str):
            return self._get_by_string_key(key)
        raise KeyError(
            f"Invalid key type: {type(key)}. Use string or (phase_id, variable_name) tuple"
        )

    def _get_by_tuple_key(self, key: tuple[PhaseID, str]) -> FloatArray:
        """Look up one variable (or time array) in a single phase."""
        if len(key) != 2:
            raise KeyError("Tuple key must have exactly 2 elements: (phase_id, variable_name)")
        # Explicit None check for mypy type safety
        raw = self._raw_solution
        if raw is None:
            raise RuntimeError("Cannot access variable: No solution data available")

        phase_id, var_name = key
        if phase_id not in self._get_phase_ids():
            raise KeyError(f"Phase {phase_id} not found in solution")

        if var_name == "time_states":
            return raw.phase_time_states.get(phase_id, np.array([], dtype=np.float64))
        if var_name == "time_controls":
            return raw.phase_time_controls.get(phase_id, np.array([], dtype=np.float64))

        # States take precedence over controls; the name's position in the
        # ordered name list is its row in the raw per-phase data.
        state_names = self._phase_state_names.get(phase_id)
        if state_names is not None and var_name in state_names:
            var_index = state_names.index(var_name)
            phase_states = raw.phase_states.get(phase_id) if phase_id in raw.phase_states else None
            if phase_states is not None and var_index < len(phase_states):
                return phase_states[var_index]

        control_names = self._phase_control_names.get(phase_id)
        if control_names is not None and var_name in control_names:
            var_index = control_names.index(var_name)
            phase_controls = (
                raw.phase_controls.get(phase_id) if phase_id in raw.phase_controls else None
            )
            if phase_controls is not None and var_index < len(phase_controls):
                return phase_controls[var_index]

        raise KeyError(f"Variable '{var_name}' not found in phase {phase_id}")

    def _get_by_string_key(self, key: str) -> FloatArray:
        """Concatenate a variable across every phase that defines it."""
        phase_ids = self._get_phase_ids()
        matching_arrays = []
        for phase_id in phase_ids:
            # Call the tuple lookup directly: __getitem__ already validated status.
            try:
                matching_arrays.append(self._get_by_tuple_key((phase_id, key)))
            except KeyError:
                continue

        if not matching_arrays:
            all_vars = []
            for phase_id in phase_ids:
                phase_vars = (
                    self._phase_state_names.get(phase_id, [])
                    + self._phase_control_names.get(phase_id, [])
                    + ["time_states", "time_controls"]
                )
                all_vars.extend([f"({phase_id}, '{var}')" for var in phase_vars])
            raise KeyError(f"Variable '{key}' not found in any phase. Available: {all_vars}")

        if len(matching_arrays) == 1:
            return matching_arrays[0]
        return np.concatenate(matching_arrays, dtype=np.float64)

    def __contains__(self, key: str | tuple[PhaseID, str]) -> bool:
        """
        Return True if ``solution[key]`` would succeed.

        A failed or missing solution contains nothing, so this returns False
        instead of propagating the ``RuntimeError`` raised by indexing.
        """
        if not self.status["success"]:
            return False
        try:
            self[key]
        except KeyError:
            return False
        return True

    def plot(
        self,
        phase_id: PhaseID | None = None,
        *variable_names: str,
        figsize: tuple[float, float] = (12.0, 8.0),
        show_phase_boundaries: bool = True,
    ) -> None:
        """
        Plot solution trajectories.

        Delegates to ``plot_multiphase_solution``. Because ``variable_names`` is
        variadic, selecting variables requires passing ``phase_id`` positionally
        (use ``None`` for all phases).

        Args:
            phase_id: Phase selection:

                - None: Plot all phases (default)
                - int: Plot specific phase only

            variable_names: Variable selection:

                - Empty: Plot all variables
                - Specified: Plot only named variables

            figsize: Figure size tuple (width, height)
            show_phase_boundaries: Display vertical lines at phase transitions

        Examples:
            Basic plotting:

            >>> solution.plot()  # All variables, all phases

            Specific phase:

            >>> solution.plot(phase_id=1)  # Phase 1 only

            Selected variables (all phases):

            >>> solution.plot(None, "altitude", "velocity", "thrust")

            Custom formatting:

            >>> solution.plot(
            ...     figsize=(16, 10),
            ...     show_phase_boundaries=True
            ... )

            Phase-specific variables:

            >>> solution.plot(1, "x_position", "y_position")  # Phase 1 positions

            No phase boundaries:

            >>> solution.plot(show_phase_boundaries=False)
        """
        from .plot import plot_multiphase_solution

        plot_multiphase_solution(self, phase_id, variable_names, figsize, show_phase_boundaries)

    def summary(self, comprehensive: bool = True) -> None:
        """
        Display a solution summary.

        Args:
            comprehensive: Summary detail level:

                - True: Full detailed summary (default); failures while
                  producing it are logged as warnings, not raised
                - False: Concise key information only

        Examples:
            Full summary:

            >>> solution.summary()  # Comprehensive details

            Concise summary:

            >>> solution.summary(comprehensive=False)  # Key information only

            Manual summary control:

            >>> # Solve without automatic summary
            >>> solution = mtor.solve_adaptive(problem, show_summary=False)
            >>> # Display summary when needed
            >>> solution.summary()

            Conditional summary:

            >>> if solution.status["success"]:
            ...     solution.summary()
            ... else:
            ...     solution.summary(comprehensive=False)  # Brief failure info
        """
        if comprehensive:
            self._show_comprehensive_summary()
            return

        # Simple summary: compute each property once.
        status = self.status
        print(f"Solution Status: {status['success']}")
        print(f"Objective: {status['objective']:.6e}")
        print(f"Total Mission Time: {status['total_mission_time']:.6f}")
        # Count phases without copying every phase's arrays via `phases`.
        print(f"Phases: {len(self._get_phase_ids())}")
        adaptive = self.adaptive
        if adaptive:
            outcome = "Converged in" if adaptive["converged"] else "Not converged after"
            print(f"Adaptive: {outcome} {adaptive['iterations']} iterations")