explainable_rl.explainability package

Submodules

explainable_rl.explainability.pdp module

class PDP(engine)

Bases: object

Partial Dependence Plots class.

__init__(engine)

Initialise the PDP class.

Parameters

engine (Engine) – Engine object containing the trained agent.

_get_denorm_actions()

Get the denormalized values of the actions.

_get_denorm_states()

Get the denormalized values of the states.

_get_digitized_pdp()

Compute the average Q-value for each state-action pair, i.e. the marginal effect of the state-action pair obtained by averaging over the other state dimensions.
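To make the marginalization concrete, here is a minimal NumPy sketch. It assumes, hypothetically, that the Q-table is a dense array with one axis per discretized state dimension followed by a final action axis; the function name `digitized_pdp` is illustrative and not part of the package.

```python
import numpy as np


def digitized_pdp(q_table: np.ndarray, state_dim: int) -> np.ndarray:
    """Average Q-values over every state dimension except `state_dim`.

    Hypothetical layout: q_table.shape == (bins_0, ..., bins_{k-1}, num_actions).
    Returns an array of shape (bins_of_state_dim, num_actions).
    """
    n_state_dims = q_table.ndim - 1  # last axis holds the actions
    if not 0 <= state_dim < n_state_dims:
        raise ValueError(f"state_dim must be in [0, {n_state_dims})")
    # Marginalize: average over all state axes other than `state_dim`,
    # keeping the action axis so each (bin, action) pair gets one value.
    other_axes = tuple(ax for ax in range(n_state_dims) if ax != state_dim)
    return q_table.mean(axis=other_axes)


q = np.arange(24, dtype=float).reshape(3, 4, 2)  # 2 state dims, 2 actions
print(digitized_pdp(q, 0))  # one row per bin of state dimension 0
```

For a Q-table of shape (3, 4, 2), `digitized_pdp(q, 0)` has shape (3, 2): one averaged Q-value per bin of state dimension 0 and per action, which is exactly what a single marginalized PDP curve per action needs.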

build_data_for_plots()

Prepare the data needed to build the PDP plots.

plot_pdp(feature, fig_name=None, savefig=True)

Build PDP plots: one marginalized plot for each state dimension.

Parameters
  • feature (str) – Feature to plot.

  • fig_name (str) – Name under which to save the plot.

  • savefig (bool) – Whether to save the plot.

explainable_rl.explainability.shap_values module

class ShapValues(engine)

Bases: object

SHAP Values class.

__init__(engine)

Initialise the ShapValues class.

Parameters

engine (Engine) – Engine object.

bin_sample()

Bin the sample.

Returns

binned_sample – Binned sample.

Return type

np.array
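A minimal sketch of binning, assuming each feature has already been normalized to [0, 1] and split into equal-width bins. `bin_sample` here is a standalone illustrative function and `bins_per_dim` is a hypothetical parameter; the package's own binning scheme may differ.

```python
import numpy as np


def bin_sample(normalized_sample, bins_per_dim):
    """Map each normalized feature value in [0, 1] to an integer bin index.

    Hypothetical sketch: assumes equal-width bins for every dimension.
    """
    if len(normalized_sample) != len(bins_per_dim):
        raise ValueError("sample and bins_per_dim must have the same length")
    binned = []
    for value, n_bins in zip(normalized_sample, bins_per_dim):
        # Interior edges of n_bins equal-width bins over [0, 1]; with only the
        # interior edges, values below 0 fall in bin 0 and values above 1
        # fall in the last bin.
        edges = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
        binned.append(int(np.digitize(value, edges)))
    return np.array(binned)


print(bin_sample([0.1, 0.25, 1.0], [4, 4, 4]))
```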

compute_shap_values(sample)

Compute the SHAP values for a given sample.

Parameters

sample (list) – Sample for which to compute the SHAP values.

Returns

  • shap_values (dict) – Dictionary with the SHAP value of each feature.

  • predicted_action (int) – Predicted action.

Return type

tuple
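One common way to approximate SHAP values is the permutation-sampling estimator of Štrumbelj and Kononenko: for each feature j, it averages f(s_plus) minus f(s_minus), where both samples take the features preceding j in a random ordering from the sample being explained and the rest from a random reference, and differ only in whether feature j itself comes from the sample (plus) or from the reference (minus). The sketch below implements that estimator under stated assumptions: the value function `f`, the uniform reference distribution over bins, `bins_per_dim` and `num_iters` are all hypothetical choices, and this is not necessarily how `compute_shap_values` is implemented.

```python
import numpy as np


def approx_shap_values(f, sample, bins_per_dim, num_iters=2000, seed=0):
    """Monte Carlo (permutation-sampling) estimate of SHAP values.

    Hypothetical sketch: `f` maps a binned sample (array of ints) to a scalar,
    and reference values are drawn uniformly from each feature's bins.
    """
    rng = np.random.default_rng(seed)
    sample = np.asarray(sample)
    n = len(sample)
    shap = {}
    for j in range(n):
        total = 0.0
        for _ in range(num_iters):
            # Random reference sample z and random feature ordering.
            z = np.array([rng.integers(b) for b in bins_per_dim])
            order = rng.permutation(n)
            pos = int(np.where(order == j)[0][0])
            # Features up to and including j in the ordering come from `sample`,
            # the remaining ones from z.
            s_plus = z.copy()
            s_plus[order[:pos + 1]] = sample[order[:pos + 1]]
            # Minus sample: identical, except feature j comes from z.
            s_minus = s_plus.copy()
            s_minus[j] = z[j]
            total += f(s_plus) - f(s_minus)
        shap[j] = total / num_iters
    return shap


linear = lambda x: 2 * x[0] + 3 * x[1]  # hypothetical value function
print(approx_shap_values(linear, [4, 0], [5, 5]))
```

For a linear `f`, each term of the average equals w_j (x_j - z_j), so the estimate converges to w_j (x_j - E[z_j]); here that is roughly 4 and -6, which makes linear functions a convenient sanity check.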

get_denorm_actions(actions)

Get the denormalized values of the actions.

Parameters

actions (list) – List of actions.

Returns

denorm_actions – List of denormalized actions.

Return type

list
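A sketch of denormalization, assuming (hypothetically) that the action range [action_min, action_max] was discretized into `num_action_bins` evenly spaced values and that actions are stored as indices into that grid. The same idea would apply to the state dimensions; all parameter names here are illustrative.

```python
import numpy as np


def denormalize_actions(actions, action_min, action_max, num_action_bins):
    """Map discrete action indices back to the original action scale.

    Hypothetical sketch: index i maps to the i-th of `num_action_bins`
    evenly spaced values between action_min and action_max.
    """
    grid = np.linspace(action_min, action_max, num_action_bins)
    return [float(grid[a]) for a in actions]


print(denormalize_actions([0, 2, 4], 10.0, 20.0, 5))
```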

normalize_sample()

Normalize the sample.

Returns

normalized_sample – Normalized sample.

Return type

list
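A sketch of per-feature min-max normalization. The bounds `feature_min` and `feature_max` are hypothetical inputs (for instance, taken from the training data). Values outside the bounds are deliberately not clipped, so they land outside [0, 1] where an outlier check can catch them.

```python
import numpy as np


def normalize_sample(sample, feature_min, feature_max):
    """Min-max scale each feature of `sample` using per-feature bounds."""
    sample = np.asarray(sample, dtype=float)
    lo = np.asarray(feature_min, dtype=float)
    hi = np.asarray(feature_max, dtype=float)
    # Avoid division by zero for constant features.
    span = np.where(hi > lo, hi - lo, 1.0)
    return [float(v) for v in (sample - lo) / span]


print(normalize_sample([5, 0, 30], [0, 0, 10], [10, 4, 20]))
```

The third feature (30, with bounds [10, 20]) maps to 2.0, signalling an out-of-range value rather than hiding it.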

plot_shap_values(sample, shap_values, predicted_action, fig_name=None, savefig=False)

Plot the SHAP values.

Parameters
  • sample (list) – Sample.

  • shap_values (dict) – SHAP values.

  • predicted_action (float) – Predicted action.

  • fig_name (str) – Figure name.

  • savefig (bool) – Whether to save the figure.

predict_action()

Predict the action.

Returns

action – Predicted action.

Return type

list
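Under the assumption of a tabular agent, predicting an action reduces to a greedy argmax over the Q-values of the sample's cell. The sketch assumes the same hypothetical Q-table layout (state axes followed by an action axis) and returns a single action index, whereas the method above is documented as returning a list.

```python
import numpy as np


def predict_action(q_table, binned_sample):
    """Greedy action for a binned state: argmax of the Q-values in that cell."""
    # Index the state axes; the remaining axis holds one Q-value per action.
    q_values = q_table[tuple(binned_sample)]
    return int(np.argmax(q_values))


q = np.zeros((2, 3, 4))  # 2 x 3 state grid, 4 actions
q[1, 2, 3] = 5.0
print(predict_action(q, [1, 2]))
```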

sample_plus_minus_samples(shap_ft, num_bins_per_shap_ft)

Generate the plus and minus samples for the feature to explain.

Parameters
  • shap_ft (int) – Index of the feature to explain.

  • num_bins_per_shap_ft (int) – Number of bins of the feature to explain.

Returns

  • s_plus (np.array) – Plus sample.

  • s_minus (np.array) – Minus sample.

Return type

tuple

verify_cell_availability(binned_sample)

Verify whether the cell has been visited.

Parameters

binned_sample (np.array) – Binned sample.

Returns

True if the cell has been visited, False otherwise.

Return type

bool
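A sketch of the visited-cell check, assuming (hypothetically) that training keeps a `visit_counts` array with one axis per state dimension, incremented each time the agent lands in a cell.

```python
import numpy as np


def verify_cell_availability(visit_counts, binned_sample):
    """Return True if the state cell indexed by `binned_sample` was visited."""
    return bool(visit_counts[tuple(binned_sample)] > 0)


counts = np.zeros((3, 3), dtype=int)
counts[1, 2] = 4
print(verify_cell_availability(counts, [1, 2]), verify_cell_availability(counts, [0, 0]))
```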

verify_outliers(binned_sample)

Verify whether the sample is an outlier.

Parameters

binned_sample (np.array) – Binned sample.

Returns

True if the sample is an outlier, False otherwise.

Return type

bool
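A sketch of an outlier check on a binned sample, assuming a sample counts as an outlier when any bin index falls outside its dimension's grid. `bins_per_dim` is a hypothetical parameter, and the package may use a different criterion.

```python
import numpy as np


def verify_outliers(binned_sample, bins_per_dim):
    """Return True if any bin index lies outside [0, bins_per_dim[i])."""
    binned = np.asarray(binned_sample)
    limits = np.asarray(bins_per_dim)
    return bool(np.any((binned < 0) | (binned >= limits)))


print(verify_outliers([0, 4], [5, 5]), verify_outliers([5, 0], [5, 5]))
```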

verify_sample_length()

Verify whether the sample length is correct.

Returns

True if the sample length is correct, False otherwise.

Return type

bool

Module contents