"""Unit tests for the Engine class (module ``tests.test_foundation.test_engine``)."""

from explainable_rl.foundation.library import *

# Import the classes under test and the shared test hyperparameters
from explainable_rl.foundation.engine import Engine
from explainable_rl.data_handler.data_handler import DataHandler
from explainable_rl.agents.q_learner import QLearningAgent
from explainable_rl.agents.sarsa import SarsaAgent
from explainable_rl.agents.sarsa_lambda import SarsaLambdaAgent
from explainable_rl.agents.double_q_learner import DoubleQLearner
from explainable_rl.environments.strategic_pricing_prediction import StrategicPricingPredictionMDP
from explainable_rl.environments.strategic_pricing_suggestion import StrategicPricingSuggestionMDP
from tests.test_hyperparams import hyperparam_dict


class TestEngine(unittest.TestCase):
    """Test the Engine class.

    One ``DataHandler`` is built per test class from the dataset named in
    ``hyperparam_dict``; each test then gets a fresh ``Engine`` wrapping it.
    """

    # Shared DataHandler, populated once by setUpClass.
    dh = None

    @classmethod
    def setUpClass(cls):
        """Set up the data handler for the tests."""
        dataset_cfg = hyperparam_dict["dataset"]
        dataset = pd.read_csv(
            dataset_cfg["data_path"],
            sep=dataset_cfg["col_delimiter"],
        )
        # The same frame is passed as both the training and the test dataset.
        cls.dh = DataHandler(
            hyperparam_dict=hyperparam_dict, dataset=dataset, test_dataset=dataset
        )

    def setUp(self) -> None:
        """Set up the engine for the tests."""
        self.engine = Engine(self.dh, hyperparam_dict=hyperparam_dict)

    def tearDown(self) -> None:
        """Tear down the engine after the tests."""
        del self.engine

    @staticmethod
    def _agent_classes():
        """Return the mapping from ``agent_type`` string to expected agent class."""
        # Built lazily (not as a class attribute) so the class body itself
        # does not depend on the agent imports at definition time.
        return {
            "q_learner": QLearningAgent,
            "sarsa": SarsaAgent,
            "sarsa_lambda": SarsaLambdaAgent,
            "double_q_learner": DoubleQLearner,
        }

    @staticmethod
    def _same_q_values(first, second):
        """Return True when two Q-tables hold the same values.

        NOTE(review): assumes the Q-table is array-like or exposes a
        ``todense()`` method (as sparse array types do) - confirm against the
        agent implementation.
        """
        # Local import keeps this helper self-contained regardless of what the
        # star import at the top of the file happens to re-export.
        import numpy as np

        def dense(table):
            to_dense = getattr(table, "todense", None)
            return np.asarray(to_dense() if callable(to_dense) else table)

        return bool(np.array_equal(dense(first), dense(second)))

    def test_create_world_agents(self):
        """Test the create_world method with different agent types."""
        for agent_type, agent_cls in self._agent_classes().items():
            # subTest reports every failing agent type instead of stopping at
            # the first one.
            with self.subTest(agent_type=agent_type):
                self.engine.agent_type = agent_type
                self.engine.create_world()
                self.assertIsInstance(self.engine.agent, agent_cls)
                self.assertIsInstance(
                    self.engine.env, StrategicPricingPredictionMDP
                )

    def test_create_agent(self):
        """Test the create_agent method with different agent types."""
        for agent_type, agent_cls in self._agent_classes().items():
            with self.subTest(agent_type=agent_type):
                self.engine.agent_type = agent_type
                # create_agent is exercised on its own, so the environment is
                # supplied by hand rather than through create_world.
                self.engine.env = StrategicPricingPredictionMDP(
                    self.dh, self.engine.bins
                )
                self.engine.create_agent()
                self.assertIsInstance(self.engine.agent, agent_cls)
                self.assertIsInstance(
                    self.engine.env, StrategicPricingPredictionMDP
                )

    def test_create_env(self):
        """Test the create_env method with different env types."""
        env_classes = {
            "strategic_pricing_predict": StrategicPricingPredictionMDP,
            "strategic_pricing_suggest": StrategicPricingSuggestionMDP,
        }
        for env_type, env_cls in env_classes.items():
            with self.subTest(env_type=env_type):
                self.engine.env_type = env_type
                self.engine.create_env()
                self.engine.agent = QLearningAgent(self.engine.env, gamma=0.8)
                self.assertIsInstance(self.engine.env, env_cls)
                self.assertIsInstance(self.engine.agent, QLearningAgent)

    def test_train_agent(self):
        """Test that train_agent leaves a Q-table whose values have changed."""
        self.engine.create_world()
        # Snapshot by value: training may update the Q-table in place.
        q_before = copy.deepcopy(self.engine.agent.Q)
        self.engine.train_agent()
        self.assertIsNotNone(self.engine.agent.Q)
        # An identity check against a deep copy is always true, so compare the
        # stored values instead.
        self.assertFalse(
            self._same_q_values(q_before, self.engine.agent.Q),
            "train_agent() left the Q-table unchanged",
        )

    def test_get_bins(self):
        """Test the get_bins method."""
        bins = self.engine._get_bins()
        # Presumably the bin counts configured in the test hyperparameters -
        # verify against tests.test_hyperparams.
        target = [10, 2, 2, 5]
        self.assertIsInstance(bins, list)
        self.assertEqual(len(bins), 4)
        self.assertEqual(bins, target)