from explainable_rl.foundation.library import *
# Import functions
from explainable_rl.agents.td import TD
from explainable_rl.environments.strategic_pricing_prediction import StrategicPricingPredictionMDP
from explainable_rl.data_handler.data_handler import DataHandler
from tests.test_hyperparams import hyperparam_dict
class TestTD(unittest.TestCase):
    """Test the TD class.

    Tests whose behaviour depends on a concrete TD algorithm
    (``test_update_q_values``, ``test_step`` and ``test_fit``) are skipped
    here; subclass test cases are expected to override them.
    """

    # Shared DataHandler, built once per test class in ``setUpClass``.
    dh = None

    @classmethod
    def setUpClass(cls) -> None:
        """Load the configured dataset once and build the shared DataHandler."""
        dataset = pd.read_csv(
            hyperparam_dict["dataset"]["data_path"],
            sep=hyperparam_dict["dataset"]["col_delimiter"],
        )
        # The same frame is passed as both the training and the test dataset.
        cls.dh = DataHandler(
            hyperparam_dict=hyperparam_dict, dataset=dataset, test_dataset=dataset
        )

    def setUp(self) -> None:
        """Create a fresh environment and TD agent (gamma=0.9) before each test."""
        self.env = StrategicPricingPredictionMDP(self.dh)
        self.agent = TD(self.env, gamma=0.9)

    def tearDown(self) -> None:
        """Drop the per-test agent after each test."""
        del self.agent

    # Skipping (rather than an empty ``pass``) keeps the base class from
    # reporting a misleading success; a subclass override is not decorated
    # and therefore runs normally.
    @unittest.skip("Implemented in tests for subclasses.")
    def test_update_q_values(self):
        """Implemented in tests for subclasses."""

    @unittest.skip("Implemented in tests for subclasses.")
    def test_step(self):
        """Implemented in tests for subclasses."""

    def test_init_q_table(self):
        """_init_q_table allocates a Q table whose shape equals env.bins."""
        self.agent.env.bins = [10, 5, 4, 6]
        self.agent._init_q_table()
        self.assertEqual(self.agent.Q.shape, (10, 5, 4, 6))

    def test_convert_to_string(self):
        """_convert_to_string joins the state entries with commas."""
        state = [0, 5, 3, 2]
        result = self.agent._convert_to_string(state)
        self.assertEqual(result, "0,5,3,2")

    def test_epsilon_greedy_policy(self):
        """With epsilon=0 the policy picks the action with the highest Q-value."""
        epsilon = 0
        state = [0, 0, 0]
        self.agent._init_q_table()
        # Make action 2 strictly the best action for state (0, 0, 0).
        self.agent.Q[0, 0, 0, 2] = 1.5
        # NOTE(review): the state is also passed explicitly below; this
        # assignment is kept in case the agent reads ``self.state`` — confirm
        # whether it is actually needed.
        self.agent.state = [0, 0, 0]
        result = self.agent._epsilon_greedy_policy(state=state, epsilon=epsilon)
        self.assertEqual(result, 2)

    def test_create_tables(self):
        """create_tables builds the Q table and the state_to_action mapping."""
        self.agent.env.bins = [10, 5, 4, 6]
        self.agent.create_tables()
        self.assertEqual(self.agent.Q.shape, (10, 5, 4, 6))
        self.assertIsNotNone(self.agent.state_to_action)

    @unittest.skip("Implemented in tests for subclasses.")
    def test_fit(self):
        """Implemented in tests for subclasses."""