explainable_rl.foundation package
Submodules
explainable_rl.foundation.agent module
- class Agent(env, gamma, verbose=False)[source]
Bases: object
Parent of all child agents (e.g. Q-learner, SARSA).
- __init__(env, gamma, verbose=False)[source]
Initialise the agent.
- Parameters
env (Environment) – Environment object.
gamma (float) – Discount factor.
verbose (bool) – Print training information.
- static _convert_to_string(state)[source]
Convert a state to a string.
- Parameters
state (list) – The state to convert.
- Returns
The state as a string.
- Return type
state_str (string)
- _epsilon_greedy_policy(state, epsilon)[source]
Epsilon-greedy policy.
- Parameters
state (int) – State.
epsilon (float) – Epsilon of epsilon-greedy policy. Defaults to 0 for pure exploitation.
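The selection rule can be sketched as below. This is a minimal illustration and not the package's implementation: the helper name `epsilon_greedy` and its `q_values` argument (a list of Q-values, one per action, for the current state) are assumptions made for the example.

```python
import random


def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick an action index from per-action Q-values (hypothetical helper)."""
    if rng.random() < epsilon:
        # Explore: choose uniformly among all actions.
        return rng.randrange(len(q_values))
    # Exploit: choose the action with the highest Q-value (first one on ties).
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

With `epsilon=0` the explore branch is never taken, so the call is purely greedy; with `epsilon=1` every call explores.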
- fit(agent_hyperparams, training_hyperparams, verbose=False, pbar=None)[source]
Fit agent to the dataset.
- Parameters
agent_hyperparams (dict) – Dictionary of agent hyperparameters.
training_hyperparams (dict) – Dictionary of training hyperparameters.
verbose (bool) – Print training information.
pbar (tqdm) – Progress bar.
- predict_actions(states, epsilon=0)[source]
Predict action for a list of states using epsilon-greedy policy.
- Parameters
states (list) – States (binned).
epsilon (float) – Epsilon of epsilon-greedy policy. Defaults to 0 for pure exploitation.
- Returns
List of recommended actions.
- Return type
list
- predict_rewards(states, actions)[source]
Predict the reward for a list of state-action pairs.
This function uses the average reward matrix (which simulates a real-life scenario).
- Parameters
states (list) – States (binned).
actions (list) – Actions (binned).
- Returns
List of predicted rewards.
- Return type
list
- uncertainty_informed_policy(state=None, epsilon=0.1, use_uncertainty=False, q_importance=0.7)[source]
Get an epsilon-greedy policy that favours more densely populated state-action pairs.
- Parameters
state (list) – Current state of the agent.
epsilon (float) – The exploration parameter.
use_uncertainty (bool) – Whether to use uncertainty informed policy.
q_importance (float) – The importance of the q value in the policy.
- Returns
Selected action.
- Return type
action (int)
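One way such a policy could combine Q-values with data density is sketched below. Everything here is an assumption for illustration: the helper names, the `visit_counts` argument (how often each action was observed in the current state), the min-max normalisation, and the linear weighting by `q_importance` are not taken from the package source.

```python
import random


def _normalise(values):
    """Scale values to [0, 1]; a constant input maps to all zeros."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


def uncertainty_informed_action(q_values, visit_counts, epsilon=0.1,
                                use_uncertainty=False, q_importance=0.7,
                                rng=random):
    """Hypothetical sketch: epsilon-greedy over a density-weighted score."""
    if rng.random() < epsilon:
        # Explore: choose uniformly among all actions.
        return rng.randrange(len(q_values))
    if not use_uncertainty:
        # Plain greedy choice on the Q-values.
        scores = list(q_values)
    else:
        # Blend normalised Q-values with normalised visit counts, so that
        # well-populated state-action pairs are favoured.
        q_norm = _normalise(q_values)
        density = _normalise(visit_counts)
        scores = [q_importance * q + (1 - q_importance) * d
                  for q, d in zip(q_norm, density)]
    return max(range(len(scores)), key=lambda a: scores[a])
```

In this sketch a lower `q_importance` shifts the choice toward actions with more supporting data, even when their Q-value is slightly lower.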
explainable_rl.foundation.engine module
- class Engine(dh, hyperparam_dict, verbose=False)[source]
Bases: object
Responsible for creating the agent and environment instances and running the training loop.
- __init__(dh, hyperparam_dict, verbose=False)[source]
Initialise engine class.
- Parameters
dh (DataHandler) – DataHandler to be given to the Environment.
hyperparam_dict (dict) – Dictionary containing all hyperparameters.
verbose (bool) – Whether print statements about the program flow should be displayed.
- _evaluate_total_agent_reward()[source]
Calculate the total reward obtained on the evaluation states using the agent’s policy.
- Returns
Total (not scaled) cumulative reward.
- Return type
total_agent_reward (float)
- _evaluate_total_hist_reward()[source]
Calculate the total reward obtained on the evaluation states using the historical data.
- Returns
Total (not scaled) cumulative reward based on historical data.
- Return type
total_hist_reward (float)
explainable_rl.foundation.environment module
- class MDP(dh)[source]
Bases: object
Define the MDP superclass from which all specific MDPs should inherit.
- __init__(dh)[source]
Initialise the Strategic Pricing MDP class.
- Parameters
dh (DataHandler) – Data handler object.
explainable_rl.foundation.library module
explainable_rl.foundation.utils module
- convert_to_list(state_str)[source]
Convert a state string to a list.
- Parameters
state_str (str) – State as a string.
- Returns
State as a list.
- Return type
list
- convert_to_string(state)[source]
Convert a state to a string.
- Parameters
state (list) – State to convert.
- Returns
State as a string.
- Return type
str
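The two conversion helpers are inverses of each other. A minimal sketch of such a round trip follows, assuming a comma-separated encoding of integer bin indices; the actual string format used by the package is not documented here, and the function names below are hypothetical.

```python
def state_to_string(state):
    """Join a list of integer bin indices into one string (assumed format)."""
    return ",".join(str(s) for s in state)


def string_to_state(state_str):
    """Split the string back into a list of integer bin indices."""
    return [int(s) for s in state_str.split(",")]


encoded = state_to_string([3, 0, 7])
decoded = string_to_state(encoded)
```

A string form is convenient because lists are not hashable, whereas strings can serve as dictionary keys, for example in a tabular Q-function.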
- decay_param(param, decay, min_param)[source]
Decay a parameter.
- Parameters
param (float) – Parameter to decay.
decay (float) – Decay rate.
min_param (float) – Minimum value of the parameter.
- Returns
Updated parameter.
- Return type
float
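A common form of this kind of decay is multiplicative with a floor, sketched below. The multiplicative rule is an assumption; the package may decay the parameter differently. The name `decay_param_sketch` is hypothetical.

```python
def decay_param_sketch(param, decay, min_param):
    """Multiply param by decay, but never go below min_param (assumed rule)."""
    return max(param * decay, min_param)


# Example: decay an exploration rate over five steps.
epsilon = 1.0
history = []
for _ in range(5):
    epsilon = decay_param_sketch(epsilon, decay=0.5, min_param=0.1)
    history.append(epsilon)
```

Once the floor is reached the parameter stays at `min_param`, which keeps a small amount of exploration for the rest of training.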
- load_data(data_path, n_samples, delimiter=',')[source]
Load data from file.
- Parameters
data_path (str) – Path to data file.
n_samples (int) – Number of samples to load.
delimiter (str) – Character that separates columns.
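The three parameters map naturally onto a delimited-file reader. The sketch below uses only the standard library and returns a list of row dictionaries; the real function's return type is not documented here and may differ, and the name `load_rows` is hypothetical.

```python
import csv
from itertools import islice


def load_rows(data_path, n_samples, delimiter=","):
    """Read at most n_samples data rows from a delimited file with a header."""
    with open(data_path, newline="") as f:
        reader = csv.DictReader(f, delimiter=delimiter)
        # islice stops after n_samples rows without reading the whole file.
        return list(islice(reader, n_samples))
```

Limiting the number of rows at read time keeps experiments on large datasets fast while the pipeline is being set up.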
- load_engine(path_name)[source]
Load engine.
- Parameters
path_name (str or List(str)) – Path(s) from which to load the engine.