"""Unit tests for the ShapValues class (module tests.test_explainability.test_shap_values)."""

from explainable_rl.foundation.library import *

# Project imports: class under test, its collaborators, and test hyperparameters
from explainable_rl.explainability.shap_values import ShapValues
from explainable_rl.foundation.engine import Engine
from explainable_rl.data_handler.data_handler import DataHandler
from tests.test_hyperparams import hyperparam_dict


class TestShapValues(unittest.TestCase):
    """Test ShapValues class.

    The data handler and engine are built, and the agent trained, once in
    ``setUpClass``; every test then shares the same ``ShapValues`` instance.
    Several tests overwrite ``shap_values.sample`` (e.g. with its normalised
    form), so ``setUp`` restores the raw sample before each test to keep the
    results independent of the order in which the test methods run.
    """

    # Raw (un-normalised) sample the tests explain. Stored as a tuple so it
    # cannot be mutated; a fresh list copy is handed to ShapValues each time.
    SAMPLE = (9, 1, 1)

    dh = None
    shap_values = None
    engine = None

    @classmethod
    def setUpClass(cls):
        """Setup TestShapValues class.

        Loads the dataset named in ``hyperparam_dict``, builds the
        DataHandler and Engine, trains the agent once, and wraps the engine
        in a ShapValues object shared by all tests.
        """
        dataset = pd.read_csv(
            hyperparam_dict['dataset']['data_path'],
            sep=hyperparam_dict['dataset']['col_delimiter'],
        )
        # The same frame is used as both the training and the test dataset.
        cls.dh = DataHandler(
            hyperparam_dict=hyperparam_dict,
            dataset=dataset,
            test_dataset=dataset,
        )
        cls.engine = Engine(dh=cls.dh, hyperparam_dict=hyperparam_dict)
        cls.engine.create_world()
        cls.engine.train_agent()
        cls.shap_values = ShapValues(engine=cls.engine)
        cls.shap_values.sample = list(cls.SAMPLE)

    def setUp(self):
        """Reset the shared ShapValues sample to its raw value."""
        # test_predict_action / test_sample_plus_minus_samples replace
        # ``sample`` with its normalised form on the shared object; without
        # this reset later tests would see (and re-normalise) that value.
        self.shap_values.sample = list(self.SAMPLE)

    def test_create_shap_values(self):
        """Test creation of ShapValues object."""
        self.assertIsInstance(self.shap_values, ShapValues)

    def test_verify_sample_length(self):
        """Test verify_sample_length method."""
        result = self.shap_values.verify_sample_length()
        self.assertIsInstance(result, bool)

    def test_bin_sample(self):
        """Test bin_sample method."""
        result = self.shap_values.bin_sample()
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), len(self.shap_values.sample))

    def test_verify_cell_availability(self):
        """Test verify_cell_availability method."""
        binned_sample = [0, 0, 0]
        result = self.shap_values.verify_cell_availability(binned_sample)
        self.assertIsInstance(result, bool)

    def test_sample_plus_minus_samples(self):
        """Test sample_plus_minus_samples method."""
        # Normalise then bin the raw sample, as predict_action's test does.
        self.shap_values.sample = self.shap_values.normalize_sample()
        self.shap_values.binned_sample = self.shap_values.bin_sample()
        shap_ft = 0
        num_bins_per_shap_ft = 10
        result_plus, result_minus = self.shap_values.sample_plus_minus_samples(
            shap_ft, num_bins_per_shap_ft
        )
        self.assertIsInstance(result_plus, list)
        self.assertIsInstance(result_minus, list)
        self.assertEqual(len(result_plus), len(result_minus))
        self.assertEqual(len(result_plus), len(self.shap_values.sample))

    def test_get_denorm_actions(self):
        """Test get_denorm_actions method."""
        actions = np.array([0, 0, 0, 0, 0])
        result = self.shap_values.get_denorm_actions(actions)
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), len(actions))

    def test_normalize_sample(self):
        """Test normalize_sample method."""
        result = self.shap_values.normalize_sample()
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), len(self.shap_values.sample))

    def test_predict_action(self):
        """Test predict_action method."""
        self.shap_values.sample = self.shap_values.normalize_sample()
        self.shap_values.binned_sample = self.shap_values.bin_sample()
        result = self.shap_values.predict_action()
        self.assertIsInstance(result, float)

    def test_verify_outliers(self):
        """Test verify_outliers method.

        An in-range binned sample must not be flagged; a sample with a bin
        index of 20 in the last feature is expected to be flagged as an
        outlier (assumes 20 exceeds the configured bin count — confirm
        against hyperparam_dict if the binning changes).
        """
        binned_sample_correct = [0, 0, 0]
        binned_sample_wrong = [0, 0, 20]
        result_correct = self.shap_values.verify_outliers(binned_sample_correct)
        result_wrong = self.shap_values.verify_outliers(binned_sample_wrong)
        self.assertIsInstance(result_correct, bool)
        self.assertIsInstance(result_wrong, bool)
        self.assertFalse(result_correct)
        self.assertTrue(result_wrong)