Source code for explainable_rl.explainability.shap_values

from explainable_rl.foundation.library import *


class ShapValues:
    """Sampling-based SHAP-style attributions for a binned Q-table agent.

    For every state feature, pairs of random binned states are drawn that
    agree on all features except the one being explained: the "plus" state
    keeps the explained sample's bin for that feature, the "minus" state uses
    a random bin. The attribution is the mean difference between the
    de-normalised greedy actions of the two states.
    """

    def __init__(self, engine):
        """Initialise the ShapValues class.

        Args:
            engine (Engine): Engine object. Its data handler (``dh``),
                environment (``env``), agent Q-table (``agent.Q``) and
                ``hyperparameters`` dictionary are read here.
        """
        self.sample = None
        self.binned_sample = None
        # Populated by compute_shap_values(); defined up front so the
        # attribute exists even before the first computation.
        self.shap_values = None

        self.features = engine.dh.state_labels
        self.action = engine.dh.action_labels
        self.minmax_scalars = engine.dh.minmax_scalars
        self.env = engine.env
        self.Q = engine.agent.Q
        self.number_of_samples = engine.hyperparameters["explainability"][
            "shap_num_samples"
        ]
        self.verbose = engine.hyperparameters["program_flow"]["verbose"]

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _log(self, *args):
        """Print ``args`` only when verbose mode is enabled."""
        if self.verbose:
            print(*args)

    def _q_values(self, binned_state):
        """Return the Q-values of every action for one binned state.

        The Q-table is indexed one full ``(state..., action)`` tuple at a
        time so that any container supporting tuple indexing keeps working.
        """
        num_actions = self.env.bins[-1]
        q_state = np.zeros(num_actions)
        state_prefix = list(binned_state)
        for a in range(num_actions):
            q_state[a] = self.Q[tuple(state_prefix + [a])]
        return q_state

    def _greedy_action(self, binned_state):
        """Return the index of the action bin with the highest Q-value."""
        return np.argmax(self._q_values(binned_state))

    def _draw_visited_pair(self, shap_ft, num_bins_per_shap_ft):
        """Draw plus/minus states until both cells have been visited."""
        while True:
            s_plus, s_minus = self.sample_plus_minus_samples(
                shap_ft, num_bins_per_shap_ft
            )
            if self.verify_cell_availability(
                s_plus
            ) and self.verify_cell_availability(s_minus):
                return s_plus, s_minus

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def compute_shap_values(self, sample):
        """Compute the SHAP values for a given sample.

        Args:
            sample (list): List with the sample to compute the SHAP values.

        Returns:
            shap_values (dict): Dictionary with the shap values for each
                feature.
            predicted_action (float): De-normalised greedy action for the
                sample, rounded to 4 decimals.

        Raises:
            ValueError: If the sample has the wrong length, falls outside the
                binned range, or maps to a cell with no non-zero Q-value.
        """
        self.sample = sample

        self._log("Verify if sample length is correct")
        if not self.verify_sample_length():
            raise ValueError("The sample length is not correct.")

        # From here on self.sample holds the normalised values.
        self.sample = self.normalize_sample()
        self.binned_sample = self.bin_sample()

        self._log("Verify if sample is an outlier")
        if self.verify_outliers(self.binned_sample):
            raise ValueError("The sample is an outlier.")

        self._log("Verify if selected cell has been visited")
        if not self.verify_cell_availability(self.binned_sample):
            raise ValueError("The cell has not been visited by the agent.")

        self._log("Predict action")
        predicted_action = self.predict_action()

        self._log("Compute shap values")
        shap_values = {}
        for shap_ft in range(len(self.features)):
            self._log("Compute shap values for feature: ", self.features[shap_ft])
            num_bins_per_shap_ft = self.env.bins[shap_ft]

            action_samples_plus = np.zeros(self.number_of_samples, dtype=int)
            action_samples_minus = np.zeros(self.number_of_samples, dtype=int)
            # The loop index is deliberately not called ``sample`` so that
            # the method argument is not shadowed.
            for sample_idx in range(self.number_of_samples):
                s_plus, s_minus = self._draw_visited_pair(
                    shap_ft, num_bins_per_shap_ft
                )
                action_samples_plus[sample_idx] = self._greedy_action(s_plus)
                action_samples_minus[sample_idx] = self._greedy_action(s_minus)

            denorm_action_samples_plus = self.get_denorm_actions(action_samples_plus)
            denorm_action_samples_minus = self.get_denorm_actions(
                action_samples_minus
            )

            difference = np.array(denorm_action_samples_plus) - np.array(
                denorm_action_samples_minus
            )
            mean_difference = round(np.mean(difference, axis=0), 4)
            shap_values.update({self.features[shap_ft]: mean_difference})

        self.shap_values = shap_values
        return shap_values, predicted_action

    def verify_sample_length(self):
        """Verify whether the sample length is correct.

        Returns:
            bool: True if the sample has one value per feature, False
                otherwise.
        """
        return len(self.sample) == len(self.features)

    def bin_sample(self):
        """Bin the (normalised) sample held in ``self.sample``.

        Returns:
            binned_sample (np.array): Binned sample, as returned by
                ``env.bin_state``.
        """
        state_dims = list(range(len(self.features)))
        return self.env.bin_state(self.sample, idxs=state_dims)

    def verify_cell_availability(self, binned_sample):
        """Verify whether the cell has been visited.

        A cell counts as visited when at least one of its actions has a
        non-zero Q-value.

        Args:
            binned_sample (np.array): Binned sample.

        Returns:
            bool: True if the cell has been visited, False otherwise.
        """
        # The last element in the bins list is the number of actions.
        num_actions = self.env.bins[-1]
        state_prefix = list(binned_sample)
        for a in range(num_actions):
            if self.Q[tuple(state_prefix + [a])] != 0:
                return True
        return False

    def verify_outliers(self, binned_sample):
        """Verify whether the sample is an outlier.

        Args:
            binned_sample (np.array): Binned sample.

        Returns:
            bool: True if any feature bin lies outside
                ``[0, number of bins)``, False otherwise.
        """
        for ft in range(len(self.features)):
            if binned_sample[ft] < 0 or binned_sample[ft] >= self.env.bins[ft]:
                return True
        return False

    def sample_plus_minus_samples(self, shap_ft, num_bins_per_shap_ft):
        """Sample the plus and minus samples.

        Both states share the same random bins for every feature other than
        ``shap_ft``; they differ only in the explained feature.

        Args:
            shap_ft (int): Feature to explain.
            num_bins_per_shap_ft (int): Number of bins for the feature to
                explain.

        Returns:
            s_plus (list): Plus sample (keeps the explained sample's bin).
            s_minus (list): Minus sample (random bin for the feature).
        """
        # Draw the explained feature first, then the remaining features in
        # index order, so the sequence of random draws stays stable.
        shap_ft_random = random.randrange(num_bins_per_shap_ft)

        s_plus = [0] * len(self.sample)
        s_minus = [0] * len(self.sample)
        s_plus[shap_ft] = int(self.binned_sample[shap_ft])
        s_minus[shap_ft] = int(shap_ft_random)

        for ft in range(len(self.features)):
            if ft == shap_ft:
                continue
            ft_random = int(random.randrange(self.env.bins[ft]))
            s_plus[ft] = ft_random
            s_minus[ft] = ft_random

        return s_plus, s_minus

    def get_denorm_actions(self, actions):
        """Get actions denormalized values.

        Args:
            actions (list): List of binned action indices (NumPy or plain
                Python integers).

        Returns:
            denorm_actions (list): List of denormalized actions. With a
                single action label the bin index is scaled to [0, 1) and
                passed through that label's scaler; otherwise the action
                label at that index is returned.
        """
        if len(self.action) != 1:
            return [self.action[a] for a in actions]

        scalar = self.minmax_scalars[self.action[0]]
        num_action_bins = self.env.bins[-1]
        denorm_actions = []
        for a in actions:
            # Divide the binned action by the number of action bins to get a
            # value between 0 and 1; np.asarray lets plain ints through too.
            denorm_a = scalar.inverse_transform(
                np.asarray(a).reshape(-1, 1) / num_action_bins
            )
            denorm_actions.append(denorm_a[0][0])
        return denorm_actions

    def normalize_sample(self):
        """Normalize sample.

        Each raw value is transformed with the scaler stored under its
        feature label.

        Returns:
            normalized_sample (list): Normalized sample.
        """
        normalized_sample = []
        for idx, ft in enumerate(self.features):
            scalar = self.minmax_scalars[ft]
            # A single-column DataFrame named after the feature is passed to
            # the scaler.
            idx_df = pd.DataFrame(
                np.array(self.sample[idx]).reshape(-1, 1), columns=[ft]
            )
            norm_ft = scalar.transform(idx_df)
            normalized_sample.append(norm_ft[0][0])
        return normalized_sample

    def predict_action(self):
        """Predict action.

        Returns:
            action (float): De-normalised greedy action for
                ``self.binned_sample``, rounded to 4 decimals.
        """
        # The argmax must be taken over the whole Q-value vector of the
        # state; using only the last looked-up Q-value always yields bin 0.
        binned_action = self._greedy_action(self.binned_sample)
        action = self.get_denorm_actions([binned_action])
        return round(action[0], 4)

    def plot_shap_values(
        self, sample, shap_values, predicted_action, fig_name=None, savefig=False
    ):
        """Plot shap values.

        Draws a horizontal bar chart sorted by value, red for negative and
        green for non-negative attributions.

        Args:
            sample (list): Sample.
            shap_values (dict): Shap values.
            predicted_action (float): Predicted action.
            fig_name (str): Figure name.
            savefig (bool): Whether to save the figure or not.
        """
        sorted_shap_values = sorted(shap_values.items(), key=lambda x: x[1])
        features = [name for name, _ in sorted_shap_values]
        values = [value for _, value in sorted_shap_values]
        colors = ["red" if value < 0 else "green" for value in values]

        plt.grid(zorder=0)
        plt.barh(features, values, color=colors, zorder=3)
        plt.title(f"Shap values for {sample} - Action: {predicted_action}")
        plt.tight_layout()
        if savefig:
            plt.savefig(fig_name, dpi=600)
        plt.show()