"""Unit tests for the SARSA(lambda) agent (tests.test_agents.test_sarsa_lambda)."""

# Import functions
import math

from tests.test_agents.test_td import TestTD
from explainable_rl.agents.sarsa_lambda import SarsaLambdaAgent
from explainable_rl.environments.strategic_pricing_prediction import StrategicPricingPredictionMDP


class TestSarsaLambda(TestTD):
    """Test the SarsaLambda class."""

    def setUp(self) -> None:
        """Set up the test class.

        Builds a ``StrategicPricingPredictionMDP`` from ``self.dh`` and a
        ``SarsaLambdaAgent`` on it with ``gamma=0.9`` and ``lambda_=0.9``.
        """
        # NOTE(review): self.dh is not assigned here; presumably it is
        # provided by TestTD (e.g. in setUpClass) — confirm.
        self.env = StrategicPricingPredictionMDP(self.dh)
        self.agent = SarsaLambdaAgent(self.env, gamma=0.9, lambda_=0.9)

    def test_update_q_values(self):
        """Test the update_q_values method.

        Seeds Q(s, a) = 1.5, Q(s', 3) = 5 and an eligibility trace
        e(s, a) = 1, performs a single ``_update_q_values`` step, and checks
        the resulting Q(s, a) and e(s, a).
        """
        self.agent.create_tables()
        self.agent.Q[0, 0, 0, 2] = 1.5
        self.agent.Q[3, 0, 0, 3] = 5
        state = [0, 0, 0]
        action = 2
        epsilon = 0.5
        next_state = [3, 0, 0]
        reward = 10
        lr = 0.1
        self.agent.e[0, 0, 0, 2] = 1

        self.agent._update_q_values(
            state=state,
            action=action,
            next_state=next_state,
            reward=reward,
            epsilon=epsilon,
            lr=lr,
        )

        result_Q = self.agent.Q[0, 0, 0, 2]
        result_e = self.agent.e[0, 0, 0, 2]

        # The expected values encode an accumulating trace: e(s, a) goes from
        # the seeded 1 to 2 before the Q update (hence the "* 2"), and is then
        # decayed by gamma * lambda_ = 0.9 * 0.9.
        # With epsilon = 0.5 the next action is stochastic, so the bootstrap
        # term Q(s', a') is either 5 (action 3) or 0 (presumably any other,
        # untouched entry — assumes create_tables zero-initialises Q; confirm).
        target_Q = [
            1.5 + lr * (10 + 0.9 * 5 - 1.5) * 2,
            1.5 + lr * (10 + 0.9 * 0 - 1.5) * 2,
        ]
        target_e = 0.9 * 0.9 * 2

        # Compare with a tolerance: the agent may evaluate these products in a
        # different order, which can change the last bits of the float.
        assert any(math.isclose(result_Q, t) for t in target_Q), (
            f"Q(s, a) = {result_Q} matches none of {target_Q}"
        )
        assert math.isclose(result_e, target_e), (
            f"e(s, a) = {result_e}, expected {target_e}"
        )