"""Unit tests for the PDP class (module tests.test_explainability.test_pdp)."""

from explainable_rl.foundation.library import *

# Import functions
from explainable_rl.explainability.pdp import PDP
from explainable_rl.foundation.engine import Engine
from explainable_rl.data_handler.data_handler import DataHandler
from tests.test_hyperparams import hyperparam_dict


class TestPDP(unittest.TestCase):
    """Test PDP class.

    All fixtures are built once in ``setUpClass`` and shared by every test
    method. This means the engine is trained only once for the whole class.
    """

    # Shared fixtures, populated once by setUpClass.
    dh = None
    pdp = None
    engine = None

    @classmethod
    def setUpClass(cls):
        """Setup TestPDP class.

        Loads the dataset configured in ``hyperparam_dict`` and wraps it in a
        ``DataHandler``. The same frame is passed as both the training and
        the test dataset. The method then trains an ``Engine`` on it and
        builds a ``PDP`` whose plot data is prepared up front so the tests
        can inspect it.
        """
        super().setUpClass()
        dataset = pd.read_csv(
            hyperparam_dict["dataset"]["data_path"],
            sep=hyperparam_dict["dataset"]["col_delimiter"],
        )
        cls.dh = DataHandler(
            hyperparam_dict=hyperparam_dict, dataset=dataset, test_dataset=dataset
        )
        cls.engine = Engine(dh=cls.dh, hyperparam_dict=hyperparam_dict)
        cls.engine.create_world()
        cls.engine.train_agent()
        cls.pdp = PDP(engine=cls.engine)
        cls.pdp.build_data_for_plots()

    def test_create_pdp(self):
        """Test creation of PDP object."""
        self.assertIsInstance(self.pdp, PDP)

    def test_get_digitized_pdp(self):
        """Test digitized pdp.

        Each digitized attribute must be a list with exactly one entry per
        state label known to the data handler.
        """
        # Hoisted: the expected length is the same for every attribute.
        n_states = len(self.dh.state_labels)
        for attr in (
            "_dig_state_actions",
            "_dig_state_actions_std",
            "_dig_state_actions_samples",
        ):
            # subTest names the offending attribute in the failure report.
            with self.subTest(attr=attr):
                value = getattr(self.pdp, attr)
                self.assertIsInstance(value, list)
                self.assertEqual(len(value), n_states)

    def test_get_denorm_actions(self):
        """Test denormalized actions."""
        self.assertIsInstance(self.pdp._denorm_actions, list)

    def test_get_denorm_states(self):
        """Test denormalized states."""
        self.assertIsInstance(self.pdp._denorm_states, list)