Source code for explainable_rl.foundation.utils

from explainable_rl.foundation.library import *


def save_engine(engine, path_name=None):
    """Save an engine to disk with ``pickle``, excluding its training dataset.

    The engine's ``dh.dataset`` attribute is detached only for the duration of
    the dump, so the training set is not written to the file. It is then
    restored, so the caller's in-memory engine is left unchanged.

    Args:
        engine (Engine): Engine to save. Must expose ``engine.dh.dataset``.
        path_name (str or os.PathLike): Destination file path. The parameter
            defaults to ``None`` only for backward compatibility; a path is
            required.

    Raises:
        TypeError: If ``path_name`` is ``None``.
    """
    if path_name is None:
        # Fail before touching the engine, rather than inside open().
        raise TypeError("save_engine() requires a path_name to write the engine to")
    data_handler = engine.dh
    dataset = data_handler.dataset
    # Used to remove the training set from the pickled engine.
    data_handler.dataset = None
    try:
        with open(path_name, "wb") as f:
            pickle.dump(engine, f)
    finally:
        # Restore so that saving does not silently strip the caller's dataset.
        data_handler.dataset = dataset
def load_engine(path_name):
    """Load one or more pickled engines from disk.

    Args:
        path_name (str or List(str)): A single file path, or an iterable of
            file paths, each pointing to a pickled engine.

    Returns:
        The unpickled object when exactly one path is supplied (either as a
        bare string or as a one-element list). Otherwise, a list of the
        unpickled objects in input order.

    Note:
        ``pickle.load`` can execute arbitrary code; only load files that come
        from a trusted source.
    """
    paths = [path_name] if isinstance(path_name, str) else path_name

    def _unpickle(file_path):
        # Open each file in binary mode and deserialize its single object.
        with open(file_path, "rb") as handle:
            return pickle.load(handle)

    loaded = [_unpickle(file_path) for file_path in paths]
    return loaded[0] if len(loaded) == 1 else loaded
def load_data(data_path, n_samples, delimiter=",", shuffle=False, random_state=None):
    """Load a tabular dataset from a CSV, Excel or Parquet file.

    Args:
        data_path (str or os.PathLike): Path to the data file. The reader is
            chosen from the file extension, case-insensitively: ``csv``,
            ``xlsx`` or ``parquet``.
        n_samples (int or None): Number of rows to return. ``None`` returns
            all rows.
        delimiter (str): Column separator. Only used for CSV files.
        shuffle (bool): If True, shuffle the rows before taking the first
            ``n_samples``. Defaults to False, which returns the first
            ``n_samples`` rows in file order (the function's existing
            behavior).
        random_state (int or None): Seed forwarded to ``DataFrame.sample``
            when ``shuffle`` is True, for reproducible subsets.

    Returns:
        pd.DataFrame: The loaded (and optionally shuffled) rows.

    Raises:
        ValueError: If the file extension is not supported.
    """
    # str() lets pathlib.Path work; lower() accepts e.g. "DATA.CSV".
    file_type = str(data_path).rsplit(".", 1)[-1].lower()
    if file_type == "csv":
        dataset = pd.read_csv(data_path, sep=delimiter)
    elif file_type == "xlsx":
        dataset = pd.read_excel(data_path)
    elif file_type == "parquet":
        dataset = pd.read_parquet(data_path)
    else:
        raise ValueError("File type not supported")
    if shuffle:
        # DataFrame.sample returns a new frame, so the result must be kept.
        dataset = dataset.sample(frac=1, random_state=random_state)
    # Positional slice of the first n_samples rows.
    return dataset.iloc[:n_samples]
def split_train_test(dataset, train_test_split=0.2):
    """Randomly partition a dataset into train and test subsets.

    Rows are shuffled first. The leading ``int(train_test_split * len)`` rows
    form the test set and the rest form the train set.

    Args:
        dataset (pd.DataFrame): Dataset to split.
        train_test_split (float): Fraction of rows assigned to the test set.

    Returns:
        tuple: ``(train_dataset, test_dataset)``, both ``pd.DataFrame``.
    """
    shuffled = dataset.sample(frac=1)
    n_test = int(train_test_split * len(shuffled))
    test_part = shuffled[:n_test]
    train_part = shuffled[n_test:]
    return train_part, test_part
def convert_to_string(state):
    """Serialize a state into a comma-separated string.

    Args:
        state (list): Sequence of state components; each is passed to ``str``.

    Returns:
        str: The components joined by ``","`` (empty string for empty input).
    """
    return ",".join(map(str, state))
def convert_to_list(state_str):
    """Parse a comma-separated state string back into a list of ints.

    Args:
        state_str (str): State as a string, e.g. ``"1,2,3"``.

    Returns:
        list: The comma-separated fields, each converted with ``int``.
    """
    fields = state_str.split(",")
    return list(map(int, fields))
def decay_param(param, decay, min_param):
    """Apply one multiplicative decay step, clamped from below.

    Args:
        param (float): Current parameter value.
        decay (float): Fractional decay rate applied this step.
        min_param (float): Floor the parameter may not drop below.

    Returns:
        float: ``param * (1 - decay)``, or ``min_param`` if that is strictly
        larger.
    """
    decayed = param * (1 - decay)
    # Mirrors max(decayed, min_param): keep the first value unless the
    # floor is strictly greater.
    if min_param > decayed:
        return min_param
    return decayed