Source code for tests.test_foundation.test_utils

import unittest

import pandas as pd

# Import functions
from explainable_rl.foundation.utils import *


class TestUtils(unittest.TestCase):
    """Test the utils functions."""

    # CSV fixture shared by the data-loading tests. The path is relative to the
    # working directory, so the suite presumably runs from the repository
    # root -- TODO confirm.
    DATA_PATH = "tests/test_env_data.csv"

    def test_convert_to_string(self):
        """Test that convert_to_string joins [1, 2, 3] into "1,2,3"."""
        state = [1, 2, 3]
        state_str = convert_to_string(state)
        self.assertEqual(state_str, "1,2,3")

    def test_convert_to_list(self):
        """Test that convert_to_list parses "1,2,3" into [1, 2, 3]."""
        state_str = "1,2,3"
        state = convert_to_list(state_str)
        self.assertEqual(state, [1, 2, 3])

    def test_decay_param(self):
        """Test decay_param with param=1, decay=0.1, min_param=0.1 gives 0.9."""
        param = 1
        decay = 0.1
        min_param = 0.1
        decayed = decay_param(param, decay, min_param)
        # The result is a float, so compare with a tolerance rather than
        # exact equality.
        self.assertAlmostEqual(decayed, 0.9)

    def test_load_dataset(self):
        """Test that load_data returns a DataFrame with n_samples rows."""
        dataset_50 = load_data(data_path=self.DATA_PATH, n_samples=50)
        dataset_25 = load_data(data_path=self.DATA_PATH, n_samples=25)
        self.assertEqual(len(dataset_50), 50)
        self.assertEqual(len(dataset_25), 25)
        self.assertIsInstance(dataset_50, pd.DataFrame)
        self.assertIsInstance(dataset_25, pd.DataFrame)

    def test_split_train_test(self):
        """Test that split_train_test(train_test_split=0.2) splits 50 rows 40/10."""
        dataset = load_data(data_path=self.DATA_PATH, n_samples=50)
        train_dataset, test_dataset = split_train_test(
            dataset, train_test_split=0.2
        )
        self.assertEqual(len(train_dataset), 40)
        self.assertEqual(len(test_dataset), 10)
        self.assertIsInstance(train_dataset, pd.DataFrame)
        self.assertIsInstance(test_dataset, pd.DataFrame)