explainable_rl.agents package
Submodules
explainable_rl.agents.double_q_learner module
- class DoubleQLearner(env, gamma, verbose=False)
Bases: TD
Double Q-Learner agent.
- __init__(env, gamma, verbose=False)
Initialise the agent.
- Parameters
env (MDP) – MDP object.
gamma (float) – Discount factor.
verbose (bool) – Whether to print information while running.
- _step(epsilon, lr, use_uncertainty=False)
Perform a step in the environment.
- Parameters
epsilon (float) – Epsilon-greedy policy parameter.
lr (float) – Learning rate.
use_uncertainty (bool) – Whether to use uncertainty informed policy.
- Returns
Whether the episode has finished.
- Return type
bool
- _update_q_values(state, action, next_state, reward, epsilon, lr, **kwargs)
Update the Q table.
- Parameters
state (list) – Current state of the agent.
action (int) – Selected action.
next_state (list) – Next state of the agent.
reward (float) – Reward for the selected action.
epsilon (float) – The exploration parameter.
lr (float) – Learning rate.
**kwargs (dict) – The keyword arguments.
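The entry above does not state the update rule. Below is a minimal sketch of the standard tabular Double Q-learning update, assuming two dict-backed Q-tables keyed by (state, action) and a fair coin to choose which table to update. The helper name `double_q_update`, the table layout and the `rng` argument are hypothetical, and this is not necessarily how `DoubleQLearner` stores or updates its tables. Because dict keys must be hashable, states are represented as tuples rather than the lists named in the parameter docs, and terminal-state handling is left out.

```python
import random
from collections import defaultdict


def double_q_update(q_a, q_b, state, action, next_state, reward,
                    actions, gamma, lr, rng=random):
    """Hypothetical tabular Double Q-learning update.

    q_a, q_b: dicts mapping (state, action) -> value (states must be hashable).
    Returns the new value of (state, action) in whichever table was updated.
    """
    # Flip a fair coin to decide which table is updated this step.
    if rng.random() < 0.5:
        update, evaluate = q_a, q_b
    else:
        update, evaluate = q_b, q_a
    # The table being updated selects the greedy next action...
    best_next = max(actions, key=lambda a: update[(next_state, a)])
    # ...and the other table evaluates it, which reduces maximisation bias.
    target = reward + gamma * evaluate[(next_state, best_next)]
    update[(state, action)] += lr * (target - update[(state, action)])
    return update[(state, action)]


q_a, q_b = defaultdict(float), defaultdict(float)
# Both tables start at zero, so the target is just the reward: returns 0.1.
double_q_update(q_a, q_b, (0,), 1, (1,), 1.0, actions=[0, 1], gamma=0.9, lr=0.1)
```

Decoupling action selection from action evaluation is the design choice that separates Double Q-learning from plain Q-learning.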
explainable_rl.agents.q_learner module
- class QLearningAgent(env, gamma, verbose=False)
Bases: TD
Q-Learning agent.
- __init__(env, gamma, verbose=False)
Initialise the agent.
- Parameters
env (MDP) – MDP object.
gamma (float) – Discount factor.
verbose (bool) – Whether to print information while running.
- _update_q_values(state, action, next_state, reward, epsilon, lr, **kwargs)
Update the Q table.
- Parameters
state (list) – Current state of the agent.
action (int) – Selected action.
next_state (list) – Next state of the agent.
reward (float) – Reward for the selected action.
epsilon (float) – The exploration parameter.
lr (float) – Learning rate.
**kwargs (dict) – The keyword arguments.
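For reference, here is a minimal sketch of the textbook tabular Q-learning update on a dict-backed Q-table. The helper `q_learning_update` and its `actions` and `done` arguments are hypothetical and not part of `QLearningAgent`'s documented API; the sketch does not use `epsilon`, because the Q-learning target itself does not depend on the exploration rate. States are tuples so they can serve as dict keys.

```python
from collections import defaultdict


def q_learning_update(q, state, action, next_state, reward, actions,
                      gamma, lr, done=False):
    """Hypothetical tabular Q-learning update on a dict-backed Q-table."""
    # Off-policy target: bootstrap from the best action in next_state,
    # regardless of which action the behaviour policy will actually take.
    next_value = 0.0 if done else max(q[(next_state, a)] for a in actions)
    target = reward + gamma * next_value
    # Move the current estimate a fraction lr of the way towards the target.
    q[(state, action)] += lr * (target - q[(state, action)])
    return q[(state, action)]


q = defaultdict(float)
q[((1,), 0)], q[((1,), 1)] = 2.0, 4.0
# Target is 1.0 + 0.5 * 4.0 = 3.0, so the new value is approximately 0.3.
q_learning_update(q, (0,), 1, (1,), reward=1.0, actions=[0, 1], gamma=0.5, lr=0.1)
```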
explainable_rl.agents.sarsa module
- class SarsaAgent(env, gamma, verbose=False)
Bases: TD
Sarsa agent.
- __init__(env, gamma, verbose=False)
Initialise the agent.
- Parameters
env (MDP) – MDP object.
gamma (float) – Discount factor.
verbose (bool) – Whether to print information while running.
- _update_q_values(state, action, next_state, reward, epsilon, lr, **kwargs)
Update the Q table.
- Parameters
state (list) – Current state of the agent.
action (int) – Selected action.
next_state (list) – Next state of the agent.
reward (float) – Reward for the selected action.
epsilon (float) – The exploration parameter.
lr (float) – Learning rate.
**kwargs (dict) – The keyword arguments.
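Sarsa differs from Q-learning only in its target. It bootstraps from the next action chosen by the same epsilon-greedy policy it is learning, and the `epsilon` parameter listed above is consistent with that. Below is a minimal sketch of the standard tabular Sarsa update. The helper `sarsa_update`, the dict-backed Q-table and the `actions`/`rng` arguments are hypothetical, and the sketch is not taken from `SarsaAgent`'s source.

```python
import random
from collections import defaultdict


def sarsa_update(q, state, action, next_state, reward, actions,
                 gamma, lr, epsilon, rng=random):
    """Hypothetical tabular Sarsa update; returns (new_value, next_action)."""
    actions = list(actions)
    # On-policy: sample the next action from the same epsilon-greedy policy.
    if rng.random() < epsilon:
        next_action = rng.choice(actions)
    else:
        next_action = max(actions, key=lambda a: q[(next_state, a)])
    # Bootstrap from the action actually chosen, not from the maximum.
    target = reward + gamma * q[(next_state, next_action)]
    q[(state, action)] += lr * (target - q[(state, action)])
    return q[(state, action)], next_action


q = defaultdict(float)
new_value, next_action = sarsa_update(q, (0,), 0, (1,), 1.0, [0, 1],
                                      gamma=0.9, lr=0.1, epsilon=0.1)
```

With `epsilon=0` the sampled action is greedy and the update coincides with Q-learning. With `epsilon > 0` the target can include an exploratory action, which is what makes Sarsa on-policy.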
explainable_rl.agents.sarsa_lambda module
- class SarsaLambdaAgent(env, gamma, verbose=False, lambda_=0.9)
Bases: TD
Sarsa Lambda agent.
- __init__(env, gamma, verbose=False, lambda_=0.9)
Initialise the agent.
- Parameters
env (MDP) – MDP object.
gamma (float) – Discount factor.
verbose (bool) – Whether to print information while running.
lambda_ (float) – Trace-decay parameter (λ) of the eligibility traces. Defaults to 0.9.
- _update_q_values(state, action, next_state, reward, epsilon, lr, **kwargs)
Update the Q table.
- Parameters
state (list) – Current state of the agent.
action (int) – Selected action.
next_state (list) – Next state of the agent.
reward (float) – Reward for the selected action.
epsilon (float) – The exploration parameter.
lr (float) – Learning rate.
**kwargs (dict) – The keyword arguments.
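The `lambda_` argument in the signature is the trace-decay factor of Sarsa(λ). Below is a minimal sketch of one tabular Sarsa(λ) step. It assumes accumulating eligibility traces (replacing traces are a common alternative) and passes `next_action` explicitly, unlike the documented method. The helper `sarsa_lambda_update` and the dict-backed `q`/`traces` tables are hypothetical; `SarsaLambdaAgent` may organise this differently.

```python
from collections import defaultdict


def sarsa_lambda_update(q, traces, state, action, next_state, next_action,
                        reward, gamma, lr, lambda_):
    """Hypothetical Sarsa(lambda) step with accumulating traces; returns the TD error."""
    # TD error for this transition, bootstrapping on-policy as in Sarsa.
    delta = reward + gamma * q[(next_state, next_action)] - q[(state, action)]
    # Accumulating trace: bump the eligibility of the pair just visited.
    traces[(state, action)] += 1.0
    # Every eligible pair shares the error in proportion to its trace,
    # then all traces decay by gamma * lambda_.
    for key in list(traces):
        q[key] += lr * delta * traces[key]
        traces[key] *= gamma * lambda_
    return delta


q, traces = defaultdict(float), defaultdict(float)
sarsa_lambda_update(q, traces, "s0", 0, "s1", 0, 1.0, gamma=0.9, lr=0.5, lambda_=0.8)
sarsa_lambda_update(q, traces, "s1", 0, "s2", 0, 2.0, gamma=0.9, lr=0.5, lambda_=0.8)
```

The second call also raises the value of ("s0", 0) through its decayed trace. This is how λ spreads credit back along the trajectory without waiting for the episode to end.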
explainable_rl.agents.td module
- class TD(env, gamma, verbose=False)
Bases: Agent
Agent class to store and update the q-table.
- __init__(env, gamma, verbose=False)
Initialise the agent.
- Parameters
env (Environment) – Environment object.
gamma (float) – Discount factor.
verbose (bool) – Print training information.
- _epsilon_greedy_policy(state=None, epsilon=0.1, Q=None)
Epsilon-greedy policy.
- Parameters
state (int) – Current state of the agent.
epsilon (float) – Exploration probability of the epsilon-greedy policy. Defaults to 0.1; set it to 0 for pure exploitation.
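Below is a minimal sketch of an epsilon-greedy choice over a dict-backed Q-table. It explores uniformly with probability `epsilon` and otherwise takes the highest-valued action, so `epsilon=0` gives pure exploitation as described above. The function name, the explicit `q` and `actions` arguments and the zero default for unseen pairs are assumptions for illustration. The documented method's `Q` parameter may play a similar role, but that is not confirmed here.

```python
import random


def epsilon_greedy_policy(q, state, actions, epsilon=0.1, rng=random):
    """Hypothetical epsilon-greedy action choice; unseen pairs count as 0.0."""
    actions = list(actions)
    # Explore: with probability epsilon, pick any action uniformly at random.
    if rng.random() < epsilon:
        return rng.choice(actions)
    # Exploit: pick the action with the highest Q value (ties go to the first).
    return max(actions, key=lambda a: q.get((state, a), 0.0))


q = {((0,), 0): 1.0, ((0,), 1): 3.0}
epsilon_greedy_policy(q, (0,), [0, 1], epsilon=0.0)  # returns 1
```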
- _get_action_scores(possible_actions, q_importance, q_values_weights, uncertainty_weights)
Get the score for each action from a state.
- Parameters
possible_actions (set) – the possible actions for an agent in a state.
q_importance (float) – the weight given to the q value relative to how often the state has been visited.
q_values_weights (dict) – the q-weight of each state-action pair.
uncertainty_weights (dict) – the count-weight of each state-action pair.
- Returns
action_scores – the weighted score of each possible action from the state.
- Return type
dict
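The entry does not give the scoring formula. One natural reading of "q_importance" is a convex combination of the two weights, and the sketch below illustrates that assumption only: `q_importance` scales the Q-based weight and `1 - q_importance` scales the visit-based weight. The helper name and the choice to key both weight dicts by action (rather than by state-action pair) are hypothetical.

```python
def get_action_scores(possible_actions, q_importance, q_values_weights,
                      uncertainty_weights):
    """Hypothetical action scoring as a convex mix of two weight dicts.

    Both weight dicts are assumed to be keyed by action.
    """
    return {
        # q_importance = 1 ranks purely by Q value; 0 ranks purely by visits.
        a: q_importance * q_values_weights[a]
        + (1 - q_importance) * uncertainty_weights[a]
        for a in possible_actions
    }


scores = get_action_scores({0, 1}, 0.7, {0: 0.2, 1: 0.8}, {0: 0.6, 1: 0.4})
best = max(scores, key=scores.get)  # action 1 (score 0.68 vs 0.32)
```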
- _get_possible_actions(state)
Get the possible actions from a state.
- Parameters
state (list) – current state of the agent.
- Returns
possible_actions – the possible actions that the agent can take from the state.
- Return type
set
- _get_q_value_weights(sum_possible_q, state, possible_actions)
Get the q value of each action as a percentage of the total q value.
- Parameters
sum_possible_q (float) – the sum of the q values for the state.
state (list) – the state of the agent.
possible_actions (set) – the possible actions that the agent can take from the state.
- Returns
state_action_counts (dict) – count of how many times each state-action pair has appeared; q_values_weights (dict) – the q-weight of each state-action pair.
- Return type
tuple of (dict, dict)
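Below is a minimal sketch of the normalisation described above, in which each action's Q value is divided by `sum_possible_q`. It returns only the q-weights and leaves out the state-action counts that the documented method also returns. The helper name, the dict-backed `q` argument and the uniform fallback for a zero sum are assumptions. The sketch also assumes non-negative Q values; with negative values the "weights" can fall outside [0, 1].

```python
def get_q_value_weights(q, sum_possible_q, state, possible_actions):
    """Hypothetical: each action's Q value as a fraction of sum_possible_q."""
    if sum_possible_q == 0:
        # Assumed fallback: no information, so weight every action equally.
        return {a: 1.0 / len(possible_actions) for a in possible_actions}
    return {a: q.get((state, a), 0.0) / sum_possible_q for a in possible_actions}


q = {("s", 0): 1.0, ("s", 1): 3.0}
get_q_value_weights(q, 4.0, "s", {0, 1})  # {0: 0.25, 1: 0.75}
```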
- static _get_uncertainty_weights(state_action_counts)
Get the uncertainty weight of each action from a state.
This is defined as the number of times a state was visited in the historical data, as a proportion of the total visits across all possible next states.
- Parameters
state_action_counts (dict) – the number of times each state has been visited in the historical data.
- Returns
uncertainty weight of each possible state.
- Return type
dict
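The definition above amounts to normalising visit counts into proportions. Below is a minimal sketch that divides each count by the total of the counts it is given. The helper name and the uniform fallback when every count is zero are assumptions, and the caller is assumed to have already restricted the dict to the relevant possible next states.

```python
def get_uncertainty_weights(state_action_counts):
    """Hypothetical: each count as a proportion of the total count."""
    total = sum(state_action_counts.values())
    if total == 0:
        # Assumed fallback: no visits recorded, so weight all keys equally.
        return {k: 1.0 / len(state_action_counts) for k in state_action_counts}
    return {k: count / total for k, count in state_action_counts.items()}


get_uncertainty_weights({"a": 1, "b": 3})  # {"a": 0.25, "b": 0.75}
```

Densely visited entries get larger weights. A policy that mixes these weights in is therefore pulled towards regions well covered by the historical data.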
- _step(epsilon, lr, use_uncertainty)
Perform a step in the environment.
- Parameters
epsilon (float) – Epsilon-greedy policy parameter.
lr (float) – Learning rate.
use_uncertainty (bool) – Whether to use uncertainty informed policy.
- Returns
Whether the episode has finished.
- Return type
bool
- _update_q_values(state, action, next_state, reward, epsilon, lr, **kwargs)
Update the Q table.
- Parameters
state (list) – Current state of the agent.
action (int) – Selected action.
next_state (list) – Next state of the agent.
reward (float) – Reward for the selected action.
epsilon (float) – The exploration parameter.
lr (float) – Learning rate.
**kwargs (dict) – The keyword arguments.
- create_tables(verbose=False)
Initialise the agent.
This resets the environment and creates the q-table and the state-to-action mapping.
- Parameters
verbose (bool) – Print information.
- fit(agent_hyperparams, training_hyperparams, verbose=False, pbar=None)
Fit the agent to the dataset.
- Parameters
agent_hyperparams (dict) – Dictionary of agent hyperparameters.
training_hyperparams (dict) – Dictionary of training hyperparameters.
verbose (bool) – Print training information.
pbar (tqdm) – Progress bar.
- uncertainty_informed_policy(state=None, epsilon=0.1, use_uncertainty=False, q_importance=0.7)
Select an action with an epsilon-greedy policy that favours more densely populated state-action pairs.
- Parameters
state (list) – Current state of the agent.
epsilon (float) – The exploration parameter.
use_uncertainty (bool) – Whether to use uncertainty informed policy.
q_importance (float) – The importance of the q value in the policy.
- Returns
action – the selected action.
- Return type
int
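Below is a minimal sketch of how the pieces documented above could combine into an uncertainty-informed epsilon-greedy choice. It explores with probability `epsilon`. Otherwise it acts greedily on Q when `use_uncertainty` is false, or ranks actions by `q_importance * q_weight + (1 - q_importance) * visit_weight` when it is true. Everything here is assumed for illustration: the function signature, the dict-backed `q` and `counts` tables, the uniform fallbacks, the convex scoring rule, and the simplification of using state-action visit counts in place of the next-state visits mentioned under `_get_uncertainty_weights`.

```python
import random


def uncertainty_informed_policy(q, counts, state, actions, epsilon=0.1,
                                use_uncertainty=False, q_importance=0.7,
                                rng=random):
    """Hypothetical epsilon-greedy policy biased towards well-visited actions."""
    actions = list(actions)
    n = len(actions)
    if rng.random() < epsilon:
        return rng.choice(actions)  # explore uniformly
    q_vals = {a: q.get((state, a), 0.0) for a in actions}
    if not use_uncertainty:
        return max(actions, key=q_vals.get)  # plain greedy on Q
    visits = {a: counts.get((state, a), 0) for a in actions}
    total_q, total_visits = sum(q_vals.values()), sum(visits.values())
    scores = {}
    for a in actions:
        # Normalise both signals; fall back to uniform when a total is zero.
        q_w = q_vals[a] / total_q if total_q else 1.0 / n
        u_w = visits[a] / total_visits if total_visits else 1.0 / n
        scores[a] = q_importance * q_w + (1 - q_importance) * u_w
    return max(actions, key=scores.get)


q = {("s", 0): 1.0, ("s", 1): 1.2}
counts = {("s", 0): 90, ("s", 1): 10}
uncertainty_informed_policy(q, counts, "s", [0, 1], epsilon=0.0)  # 1
uncertainty_informed_policy(q, counts, "s", [0, 1], epsilon=0.0,
                            use_uncertainty=True)  # 0
```

In this example, action 1 has the slightly higher Q value but was rarely seen in the data. With `use_uncertainty=True` the policy prefers the well-covered action 0, which is the trade-off `q_importance` controls.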