Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/MonteCarloExamples/scenarioBskSimAttFeedbackMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
sNavTransName = "sNavTransMsg"
attGuidName = "attGuidMsg"

def run(show_plots):
def run(show_plots, recordSimParams=False):
"""This function is called by the py.test environment."""

# A MonteCarlo simulation can be created using the `MonteCarlo` module.
Expand All @@ -112,7 +112,7 @@ def run(show_plots):
monteCarlo = Controller()
monteCarlo.setSimulationFunction(scenario_AttFeedback.scenario_AttFeedback) # Required: function that configures the base scenario
monteCarlo.setExecutionFunction(scenario_AttFeedback.runScenario) # Required: function that runs the scenario
monteCarlo.setExecutionCount(4) # Required: Number of MCs to run
monteCarlo.setExecutionCount(3) # Required: Number of MCs to run

monteCarlo.setArchiveDir(path + "/scenarioBskSimAttFeedbackMC") # Optional: If/where to save retained data.
monteCarlo.setShouldDisperseSeeds(True) # Optional: Randomize the seed for each module
Expand Down Expand Up @@ -144,6 +144,7 @@ def run(show_plots):
retentionPolicy.addMessageLog(attGuidName, ["sigma_BR", "omega_BR_B"])
retentionPolicy.setDataCallback(displayPlots)
monteCarlo.addRetentionPolicy(retentionPolicy)
monteCarlo.setRecordSimParams(recordSimParams)

failures = monteCarlo.executeSimulations()

Expand Down
62 changes: 62 additions & 0 deletions src/tests/test_bskMcTestScript.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
import shutil
import sys
import pytest
import json
import ast
import numpy as np

# Check if Bokeh is available
bokeh_spec = importlib.util.find_spec("bokeh")
Expand Down Expand Up @@ -80,3 +83,62 @@ def test_scenarioBskMcScenarios(show_plots):
shutil.rmtree(path + "/../../examples/MonteCarloExamples/scenarioBskSimAttFeedbackMC/")

assert testFailCount < 1, testMessages

def test_dispersionApplicationMc():
testMessages = []

scene_plt_dispersions = importlib.import_module('scenarioBskSimAttFeedbackMC')

try:
figureList = scene_plt_dispersions.run(False, recordSimParams=True)

except Exception as err:
testFailCount += 1
testMessages.append(f"Error in {'scenarioBskSimAttFeedbackMC'}: {str(err)}")

# check path existence
assert os.path.exists(path + "/../../examples/MonteCarloExamples/scenarioBskSimAttFeedbackMC/")

# define file paths
dispPath = os.path.join(path + "/../../examples/MonteCarloExamples/scenarioBskSimAttFeedbackMC/run0.json")
attrPath = os.path.join(path + "/../../examples/MonteCarloExamples/scenarioBskSimAttFeedbackMC/run0attributes.json")

with open(dispPath, 'r') as f:
dispData = json.load(f)
with open(attrPath, 'r') as f:
attrData = json.load(f)

# check if keys are identical
dispDataKeys = set(dispData.keys())
attrDataKeys = set(attrData.keys())

assert dispDataKeys == attrDataKeys, "Key sets differ"

for key in dispDataKeys:

dispVal = ast.literal_eval(dispData[key])

if type(dispVal) == list:
arrayDisp = np.array(dispVal).flatten()
attrVal = ast.literal_eval(attrData[key])
arrayAttr = np.array(attrVal).flatten()

np.testing.assert_allclose(
arrayAttr,
arrayDisp,
atol=1e-12,
err_msg=f"Numerical mismatch for parameter: {key} in the first simulation run"
)

else:
dispVal = dispData[key]
attrVal = attrData[key]

assert dispVal == attrVal, (
f"Mismatch for parameter: {key} in the first simulation run. "
f"Expected: '{dispVal}', Found: '{attrVal}'"
)

# Clean up
if os.path.exists(path + "/../../examples/MonteCarloExamples/scenarioBskSimAttFeedbackMC/"):
shutil.rmtree(path + "/../../examples/MonteCarloExamples/scenarioBskSimAttFeedbackMC/")
69 changes: 68 additions & 1 deletion src/utilities/MonteCarlo/Controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import sys
import traceback
import warnings
import ast
import re

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
Expand Down Expand Up @@ -286,6 +288,16 @@ def getRetainedDatas(self, cases):
for case in cases:
yield self.getRetainedData(case) # call this method recursively, yielding the result

def setRecordSimParams(self, recordSimParams):
"""
Record attributes from dispersions list after the dispersions are applied

Args:
recordSimParams: bool
Whether to save the attributes .json file.
"""
self.simParams.recordSimParams = recordSimParams

def getParameters(self, caseNumber):
"""
Get the parameters used for a particular run of the montecarlo
Expand Down Expand Up @@ -728,6 +740,7 @@ def __init__(self, creationFunction, executionFunction, configureFunction,
self.dispersionMag = {}
self.saveDispMag = False
self.showProgressBar = showProgressBar
self.recordSimParams = False



Expand Down Expand Up @@ -828,7 +841,18 @@ def __call__(cls, params):
for variable, value in list(modifications.items()):
if simParams.verbose:
print(f"Setting attribute {variable} to {value} on simInstance")
setattr(simInstance, variable, value)
parsedValue = ast.literal_eval(value)
cls.setNestedAttr(simInstance, variable, parsedValue)

# save attributes for verification
if simParams.recordSimParams == True:
setAttributes = {}
for variable in modifications.keys():
setAttribute = cls.getNestedAttr(simInstance, variable)
setAttributes[variable] = str(setAttribute)
recordFileName = simParams.filename + "attributes.json"
with open(recordFileName, 'w') as outfile:
json.dump(setAttributes, outfile, indent=4)

# setup data logging
if len(simParams.retentionPolicies) > 0:
Expand Down Expand Up @@ -874,6 +898,49 @@ def __call__(cls, params):
traceback.print_exc()
return (False, simParams.index) # there was an error

@staticmethod
def setNestedAttr(obj, attrString, value):
"""
A helper function to set a nested attribute on an object using a string.
Handles both attribute access ('.') and item access ('[]').
"""
parts = re.split(r'\.|\[(\d+)\]', attrString)
parts = [p for p in parts if p is not None and p != '']
currentObj = obj

# traverse all parts except the last one to get the parent object
for part in parts[:-1]:
if part.isdigit():
currentObj = currentObj[int(part)]
else:
currentObj = getattr(currentObj, part)

# set the final attribute on the parent object
lastPart = parts[-1]
if lastPart.isdigit():
currentObj[int(lastPart)] = value
else:
setattr(currentObj, lastPart, value)

@staticmethod
def getNestedAttr(obj, attrString):
"""
A helper function to get a nested attribute value on an object using a string.
Handles both attribute access ('.') and item access ('[]').
"""
parts = re.split(r'\.|\[(\d+)\]', attrString)
parts = [p for p in parts if p is not None and p != '']
currentObj = obj

# traverse all parts to get the final object/value
for part in parts:
if part.isdigit():
currentObj = currentObj[int(part)]
else:
currentObj = getattr(currentObj, part)

return currentObj

@staticmethod
def disperseSeeds(simInstance):
"""
Expand Down
Loading