Agent¶
maro.rl.agent.abs_agent¶
-
class
maro.rl.agent.abs_agent.AbsAgent(model: maro.rl.model.learning_model.AbsCoreModel, config)[source]¶ Bases:
abc.ABC
Abstract RL agent class.
It serves as a sandbox for the RL algorithm, with scenario-specific details excluded so that the focus stays on abstract algorithm development. Environment observations and decision events are converted to a uniform format before being passed in, and the output is converted to an environment-executable format before being returned to the environment. Its key responsibility is optimizing the policy based on interaction with the environment.
- Parameters
model (AbsCoreModel) – Task model or container of task models required by the algorithm.
config – Settings for the algorithm.
-
abstract
choose_action(state)[source]¶ This method uses the underlying model(s) to compute an action from a shaped state.
- Parameters
state – A state object shaped by a StateShaper to conform to the model input format.
- Returns
The action to be taken given state. It is usually necessary to use an ActionShaper to convert this to an environment executable action.
-
dump_model_to_file(path: str)[source]¶ Dump the algorithm’s trainable models to disk.
Dump trainable models to the specified directory. The model file is always prefixed with the agent’s name.
- Parameters
path (str) – path to the directory where the models are saved.
-
abstract
learn(*args, **kwargs)[source]¶ Algorithm-specific training logic.
The parameters are data to train the underlying model on. Algorithm-specific loss and optimization should be reflected here.
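The contract described above (a constructor taking a model and a config, plus abstract choose_action and learn methods) can be sketched with the standard abc module. This is a minimal illustration, not the MARO class: SketchAgent, GreedyTableAgent, the dict-based "model" and the step_size setting are all hypothetical names chosen for the example.

```python
from abc import ABC, abstractmethod


class SketchAgent(ABC):
    """Hypothetical stand-in for the AbsAgent interface (not the MARO class)."""

    def __init__(self, model, config):
        self.model = model      # task model or container of task models
        self.config = config    # settings for the algorithm

    @abstractmethod
    def choose_action(self, state):
        """Compute an action from a shaped state."""

    @abstractmethod
    def learn(self, *args, **kwargs):
        """Algorithm-specific training logic."""


class GreedyTableAgent(SketchAgent):
    """Toy subclass: the 'model' is a dict mapping state -> list of action values."""

    def choose_action(self, state):
        values = self.model[state]
        # Index of the largest action value.
        return max(range(len(values)), key=values.__getitem__)

    def learn(self, state, action, target):
        # Move the stored value toward the target by the configured step size.
        step = self.config["step_size"]
        self.model[state][action] += step * (target - self.model[state][action])


agent = GreedyTableAgent({"s0": [0.0, 1.0]}, {"step_size": 0.5})
first_action = agent.choose_action("s0")
agent.learn("s0", 0, 4.0)
second_action = agent.choose_action("s0")
```

Because both methods are abstract, the base class cannot be instantiated directly; only subclasses that implement both can.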
maro.rl.agent.dqn¶
-
class
maro.rl.agent.dqn.DQN(model: maro.rl.model.learning_model.SimpleMultiHeadModel, config: maro.rl.agent.dqn.DQNConfig)[source]¶ Bases:
maro.rl.agent.abs_agent.AbsAgent
The Deep-Q-Networks algorithm.
See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details.
- Parameters
model (SimpleMultiHeadModel) – Q-value model.
config – Configuration for DQN algorithm.
-
choose_action(state: numpy.ndarray) → Union[int, numpy.ndarray][source]¶ This method uses the underlying model(s) to compute an action from a shaped state.
- Parameters
state – A state object shaped by a StateShaper to conform to the model input format.
- Returns
The action to be taken given state. It is usually necessary to use an ActionShaper to convert this to an environment executable action.
-
class
maro.rl.agent.dqn.DQNConfig(reward_discount: float, target_update_freq: int, epsilon: float = 0.0, tau: float = 0.1, double: bool = True, advantage_type: str = None, loss_cls=<class 'torch.nn.modules.loss.MSELoss'>)[source]¶ Bases:
object
Configuration for the DQN algorithm.
- Parameters
reward_discount (float) – Reward decay as defined in standard RL terminology.
epsilon (float) – Exploration rate for epsilon-greedy exploration. Defaults to 0.0.
tau (float) – Soft update coefficient, i.e., target_model = tau * eval_model + (1 - tau) * target_model. Defaults to 0.1.
double (bool) – If True, the next Q values will be computed according to the double DQN algorithm, i.e., q_next = Q_target(s, argmax(Q_eval(s, a))). Otherwise, q_next = max(Q_target(s, a)). See https://arxiv.org/pdf/1509.06461.pdf for details. Defaults to True.
advantage_type (str) – Advantage mode for the dueling architecture. Defaults to None, in which case it is assumed that the regular Q-value model is used.
loss_cls – Loss function class for evaluating TD errors. Defaults to torch.nn.MSELoss.
target_update_freq (int) – Number of training rounds between target model updates.
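The formulas quoted for tau and double can be made concrete with plain Python lists standing in for model parameters and per-action Q values. This is only a sketch of the arithmetic under stated assumptions: the function names are hypothetical, terminal-state handling is omitted, and the real agent operates on torch tensors.

```python
def soft_update(target_params, eval_params, tau):
    """target = tau * eval + (1 - tau) * target, element-wise."""
    return [tau * e + (1.0 - tau) * t for e, t in zip(eval_params, target_params)]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def next_q_value(q_eval_next, q_target_next, double):
    """Next-state Q value used in the TD target.

    q_eval_next and q_target_next hold the per-action Q values of the next
    state from the evaluation model and the target model respectively.
    """
    if double:
        # Double DQN: the eval model selects the action, the target model scores it.
        return q_target_next[argmax(q_eval_next)]
    # Regular DQN: the target model both selects and scores the action.
    return max(q_target_next)


def td_target(reward, q_next, reward_discount):
    return reward + reward_discount * q_next


q_eval_next = [1.0, 3.0, 2.0]
q_target_next = [5.0, 0.5, 4.0]
double_q = next_q_value(q_eval_next, q_target_next, double=True)
vanilla_q = next_q_value(q_eval_next, q_target_next, double=False)
target = td_target(1.0, double_q, reward_discount=0.5)
new_target_params = soft_update([0.0, 10.0], [1.0, 0.0], tau=0.25)
```

In this toy input the two variants disagree: the regular rule takes the target model's largest value, while the double rule takes the target model's value for the action the evaluation model prefers.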
-
advantage_type¶
-
double¶
-
epsilon¶
-
loss_func¶
-
reward_discount¶
-
target_update_freq¶
-
tau¶
maro.rl.agent.ddpg¶
-
class
maro.rl.agent.ddpg.DDPG(model: maro.rl.model.learning_model.SimpleMultiHeadModel, config: maro.rl.agent.ddpg.DDPGConfig, explorer: maro.rl.exploration.noise_explorer.NoiseExplorer = None)[source]¶ Bases:
maro.rl.agent.abs_agent.AbsAgent
The Deep Deterministic Policy Gradient (DDPG) algorithm.
References: https://arxiv.org/pdf/1509.02971.pdf https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg
- Parameters
model (SimpleMultiHeadModel) – DDPG policy and q-value models.
config – Configuration for DDPG algorithm.
explorer (NoiseExplorer) – A NoiseExplorer instance for generating exploratory actions. Defaults to None.
-
choose_action(state) → Union[float, numpy.ndarray][source]¶ This method uses the underlying model(s) to compute an action from a shaped state.
- Parameters
state – A state object shaped by a StateShaper to conform to the model input format.
- Returns
The action to be taken given state. It is usually necessary to use an ActionShaper to convert this to an environment executable action.
-
class
maro.rl.agent.ddpg.DDPGConfig(reward_discount: float, q_value_loss_func: Callable, target_update_freq: int, policy_loss_coefficient: float = 1.0, tau: float = 1.0)[source]¶ Bases:
object
Configuration for the DDPG algorithm.
- Parameters
reward_discount (float) – Reward decay as defined in standard RL terminology.
q_value_loss_func (Callable) – Loss function for the Q-value estimator.
target_update_freq (int) – Number of training rounds between policy target model updates.
policy_loss_coefficient (float) – The coefficient for policy loss in the total loss function, e.g., loss = q_value_loss + policy_loss_coefficient * policy_loss. Defaults to 1.0.
tau (float) – Soft update coefficient, e.g., target_model = tau * eval_model + (1-tau) * target_model. Defaults to 1.0.
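The two formulas in the parameter descriptions can be checked with a few lines of arithmetic. The sketch below uses plain floats and lists with hypothetical function names; it only illustrates how the coefficient weights the policy loss and why the default tau of 1.0 amounts to copying the evaluation model into the target model.

```python
def ddpg_total_loss(q_value_loss, policy_loss, policy_loss_coefficient=1.0):
    """loss = q_value_loss + policy_loss_coefficient * policy_loss."""
    return q_value_loss + policy_loss_coefficient * policy_loss


def soft_update(target_params, eval_params, tau=1.0):
    """target = tau * eval + (1 - tau) * target, element-wise."""
    return [tau * e + (1.0 - tau) * t for e, t in zip(eval_params, target_params)]


total = ddpg_total_loss(q_value_loss=2.0, policy_loss=-0.5, policy_loss_coefficient=0.5)
# With tau = 1.0 the old target parameters are discarded entirely.
hard_copied = soft_update(target_params=[9.0, 9.0], eval_params=[3.0, 4.0])
```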
-
policy_loss_coefficient¶
-
q_value_loss_func¶
-
reward_discount¶
-
target_update_freq¶
-
tau¶
maro.rl.agent.policy_optimization¶
Explorer¶
maro.rl.exploration.abs_explorer¶
maro.rl.exploration.epsilon_greedy_explorer¶
-
class
maro.rl.exploration.epsilon_greedy_explorer.EpsilonGreedyExplorer(num_actions: int, epsilon: float = 0.0)[source]¶ Bases:
maro.rl.exploration.abs_explorer.AbsExplorer
Epsilon-greedy explorer for discrete action spaces.
- Parameters
num_actions (int) – Number of all possible actions.
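As a rough illustration of epsilon-greedy exploration over a discrete action space, the sketch below replaces the model's action with a uniformly random one with probability epsilon. The function name and the rng argument are hypothetical, and whether the library samples the random action uniformly over all num_actions choices is an assumption.

```python
import random


def epsilon_greedy(greedy_action, num_actions, epsilon, rng=random):
    """With probability epsilon return a uniformly random action, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(num_actions)
    return greedy_action


rng = random.Random(0)
# epsilon = 0.0 never explores; epsilon = 1.0 always does.
always_greedy = [epsilon_greedy(2, 4, 0.0, rng) for _ in range(100)]
always_random = [epsilon_greedy(2, 4, 1.0, rng) for _ in range(100)]
```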
maro.rl.exploration.noise_explorer¶
-
class
maro.rl.exploration.noise_explorer.GaussianNoiseExplorer(min_action: Union[float, list, numpy.ndarray] = None, max_action: Union[float, list, numpy.ndarray] = None, noise_mean: Union[float, list, numpy.ndarray] = 0.0, noise_stddev: Union[float, list, numpy.ndarray] = 0.0, is_relative: bool = False)[source]¶ Bases:
maro.rl.exploration.noise_explorer.NoiseExplorer
Explorer that adds random noise sampled from a Gaussian distribution to a model-generated action.
-
class
maro.rl.exploration.noise_explorer.NoiseExplorer(min_action: Union[float, list, numpy.ndarray] = None, max_action: Union[float, list, numpy.ndarray] = None)[source]¶ Bases:
maro.rl.exploration.abs_explorer.AbsExplorer
Explorer that adds random noise to a model-generated action.
-
class
maro.rl.exploration.noise_explorer.UniformNoiseExplorer(min_action: Union[float, list, numpy.ndarray] = None, max_action: Union[float, list, numpy.ndarray] = None, noise_lower_bound: Union[float, list, numpy.ndarray] = 0.0, noise_upper_bound: Union[float, list, numpy.ndarray] = 0.0)[source]¶ Bases:
maro.rl.exploration.noise_explorer.NoiseExplorer
Explorer that adds random noise sampled from a uniform distribution to a model-generated action.
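The noise explorers can be pictured as "perturb, then clip". The sketch below handles scalar actions only and uses hypothetical function names; it assumes the result is clipped to [min_action, max_action] when those bounds are given, and it does not model the is_relative option, whose exact semantics are not described on this page.

```python
import random


def clip(value, low, high):
    """Clamp value to [low, high]; a bound of None is ignored."""
    if low is not None:
        value = max(value, low)
    if high is not None:
        value = min(value, high)
    return value


def gaussian_explore(action, noise_mean=0.0, noise_stddev=0.0,
                     min_action=None, max_action=None, rng=random):
    noise = rng.gauss(noise_mean, noise_stddev)
    return clip(action + noise, min_action, max_action)


def uniform_explore(action, noise_lower_bound=0.0, noise_upper_bound=0.0,
                    min_action=None, max_action=None, rng=random):
    noise = rng.uniform(noise_lower_bound, noise_upper_bound)
    return clip(action + noise, min_action, max_action)


rng = random.Random(1)
gaussian_samples = [gaussian_explore(0.9, 0.0, 1.0, 0.0, 1.0, rng) for _ in range(200)]
uniform_samples = [uniform_explore(0.5, -0.2, 0.2, rng=rng) for _ in range(200)]
```

With the default noise settings (zero mean and zero standard deviation, or zero-width uniform bounds) both functions return the action unchanged.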
Scheduler¶
maro.rl.scheduling.scheduler¶
-
class
maro.rl.scheduling.scheduler.Scheduler(max_iter: int = - 1)[source]¶ Bases:
object
Scheduler that generates new parameters each iteration.
- Parameters
max_iter (int) – Maximum number of iterations. If -1, using the scheduler in a for-loop will result in an infinite loop unless the check_for_stopping method is implemented.
-
property
iter¶
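The interaction between max_iter and check_for_stopping can be sketched with Python's iterator protocol. CountingScheduler and StopAfterFive are hypothetical classes written for this illustration; the real Scheduler yields parameters rather than the iteration index, and its internals may differ.

```python
class CountingScheduler:
    """Hypothetical iterator that stops at max_iter or when check_for_stopping() is true."""

    def __init__(self, max_iter=-1):
        self.max_iter = max_iter
        self.iter = -1  # index of the current iteration

    def check_for_stopping(self):
        # Default: never request an early stop.
        return False

    def __iter__(self):
        return self

    def __next__(self):
        self.iter += 1
        if self.iter == self.max_iter or self.check_for_stopping():
            raise StopIteration
        return self.iter


class StopAfterFive(CountingScheduler):
    def check_for_stopping(self):
        return self.iter >= 5


bounded = list(CountingScheduler(max_iter=3))
# max_iter = -1 would loop forever without the overridden stopping check.
early_stopped = list(StopAfterFive(max_iter=-1))
```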
maro.rl.scheduling.simple_parameter_scheduler¶
-
class
maro.rl.scheduling.simple_parameter_scheduler.LinearParameterScheduler(max_iter: int, parameter_names: List[str], start: Union[float, list, tuple, numpy.ndarray], end: Union[float, list, tuple, numpy.ndarray])[source]¶ Bases:
maro.rl.scheduling.scheduler.Scheduler
Static exploration parameter generator based on a linear schedule.
- Parameters
max_iter (int) – Maximum number of iterations.
parameter_names (List[str]) – List of exploration parameter names.
start (Union[float, list, tuple, np.ndarray]) – Exploration parameter values for the first episode. These values must correspond to parameter_names.
end (Union[float, list, tuple, np.ndarray]) – Exploration parameter values for the last episode. These values must correspond to parameter_names.
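A linear schedule of this kind can be sketched as a generator that interpolates each named parameter from its start value to its end value. The function below is hypothetical: it assumes start and end are given as sequences aligned with parameter_names, that the end value is reached on the last iteration (max_iter - 1), and that each step yields a dict, none of which is confirmed by this page.

```python
def linear_schedule(max_iter, parameter_names, start, end):
    """Yield one {name: value} dict per iteration, interpolating start -> end."""
    steps = max(max_iter - 1, 1)  # avoid dividing by zero when max_iter == 1
    for i in range(max_iter):
        frac = i / steps
        yield {name: s + (e - s) * frac
               for name, s, e in zip(parameter_names, start, end)}


schedule = list(linear_schedule(5, ["epsilon"], [0.4], [0.0]))
```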
-
class
maro.rl.scheduling.simple_parameter_scheduler.TwoPhaseLinearParameterScheduler(max_iter: int, parameter_names: List[str], split: float, start: Union[float, list, tuple, numpy.ndarray], mid: Union[float, list, tuple, numpy.ndarray], end: Union[float, list, tuple, numpy.ndarray])[source]¶ Bases:
maro.rl.scheduling.scheduler.Scheduler
Exploration parameter generator based on two linear schedules joined together.
- Parameters
max_iter (int) – Maximum number of iterations.
parameter_names (List[str]) – List of exploration parameter names.
split (float) – The point where the switch from the first linear schedule to the second occurs.
start (Union[float, list, tuple, np.ndarray]) – Exploration parameter values for the first episode. These values must correspond to parameter_names.
mid (Union[float, list, tuple, np.ndarray]) – Exploration parameter values where the switch from the first linear schedule to the second occurs. In other words, this is the exploration rate where the first linear schedule ends and the second begins. These values must correspond to parameter_names.
end (Union[float, list, tuple, np.ndarray]) – Exploration parameter values for the last episode. These values must correspond to parameter_names.
- Returns
An iterator over the series of exploration rates from episode 0 to max_iter - 1.
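The two-phase schedule can be sketched for a single scalar parameter as a piecewise-linear generator. The function name is hypothetical, and the sketch assumes that split is a fraction of max_iter (so the switch happens at iteration int(split * max_iter)); the page only says that split is a float marking the switch point.

```python
def two_phase_linear_schedule(max_iter, split, start, mid, end):
    """Yield one value per iteration: start -> mid in phase one, mid -> end in phase two."""
    split_iter = int(split * max_iter)  # assumption: split is a fraction of max_iter
    for i in range(max_iter):
        if i < split_iter:
            yield start + (mid - start) * i / split_iter
        else:
            remaining = max(max_iter - 1 - split_iter, 1)
            yield mid + (end - mid) * (i - split_iter) / remaining


rates = list(two_phase_linear_schedule(max_iter=11, split=0.5, start=1.0, mid=0.5, end=0.0))
```

In this example the rate falls from 1.0 to 0.5 over the first five iterations and from 0.5 to 0.0 over the remaining ones.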
Storage¶
maro.rl.storage.abs_store¶
-
class
maro.rl.storage.abs_store.AbsStore[source]¶ Bases:
abc.ABC
A data store abstraction that supports get, put, update and sample operations.
-
filter(filters: Sequence[Callable])[source]¶ Multi-filter method.
The input to one filter is the output from the previous filter.
- Parameters
filters (Sequence[Callable]) – Filter list, each item is a lambda function, e.g., [lambda d: d['a'] == 1 and d['b'] == 1].
- Returns
Filtered indexes and corresponding objects.
-
abstract
get(indexes: Sequence)[source]¶ Get contents.
- Parameters
indexes – A sequence of indexes to retrieve contents at.
- Returns
Retrieved contents.
-
put(contents: Sequence)[source]¶ Put new contents.
- Parameters
contents (Sequence) – Contents to be added to the store.
- Returns
The indexes where the newly added entries reside in the store.
-
abstract
sample(size: int, weights: Sequence, replace: bool = True)[source]¶ Obtain a random sample from the experience pool.
- Parameters
size (int) – Sample size.
weights (Sequence) – A sequence of sampling weights.
replace (bool) – If True, sampling is performed with replacement. Defaults to True.
- Returns
A random sample from the experience pool.
-
abstract
update(indexes: Sequence, contents: Sequence)[source]¶ Update the store contents at given positions.
- Parameters
indexes (Sequence) – Positions where updates are to be made.
contents (Sequence) – Item list, which has the same length as indexes.
- Returns
The indexes where store contents are updated.
maro.rl.storage.simple_store¶
-
class
maro.rl.storage.simple_store.OverwriteType(value)[source]¶ Bases:
enum.Enum
An enumeration.
-
RANDOM= 'random'¶
-
ROLLING= 'rolling'¶
-
class
maro.rl.storage.simple_store.SimpleStore(keys: list, capacity: int = - 1, overwrite_type: maro.rl.storage.simple_store.OverwriteType = None)[source]¶ Bases:
maro.rl.storage.abs_store.AbsStore
An implementation of AbsStore for experience storage in RL.
This implementation uses a dictionary of lists as the internal data structure. The objects for each key are stored in a list. To be useful for experience storage in RL, uniformity checks are performed during put operations to ensure that the list lengths stay the same for all keys at all times. Both unlimited and limited storage are supported.
- Parameters
keys (list) – List of keys identifying each column.
capacity (int) – If negative, the store is of unlimited capacity. Defaults to -1.
overwrite_type (OverwriteType) – If storage capacity is bounded, this specifies how existing entries are overwritten when the capacity is exceeded. Two types of overwrite behavior are supported:
- Rolling, where overwrite occurs sequentially with wrap-around.
- Random, where overwrite occurs randomly among filled positions.
Alternatively, the user may also specify overwrite positions (see put).
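The dictionary-of-lists layout, the uniformity check and rolling overwrite can be sketched in a few dozen lines. RollingStore is a hypothetical class written for illustration, not SimpleStore itself: it supports only rolling overwrite, raises built-in exceptions instead of the library's own, and its handling of the overwrite position is an assumption.

```python
class RollingStore:
    """Hypothetical bounded dict-of-lists store with rolling overwrite."""

    def __init__(self, keys, capacity=-1):
        self.data = {key: [] for key in keys}
        self.capacity = capacity  # negative means unlimited
        self._next = 0            # next position to overwrite once the store is full

    def __len__(self):
        return len(next(iter(self.data.values())))

    def put(self, contents):
        if set(contents) != set(self.data):
            raise KeyError("contents must have the same keys as the store")
        lengths = {len(values) for values in contents.values()}
        if len(lengths) != 1:
            raise ValueError("all columns must have the same length")
        indexes = []
        for row in range(lengths.pop()):
            if self.capacity < 0 or len(self) < self.capacity:
                # Room left: append a new row to every column.
                index = len(self)
                for key in self.data:
                    self.data[key].append(contents[key][row])
            else:
                # Full: overwrite sequentially with wrap-around.
                index = self._next
                for key in self.data:
                    self.data[key][index] = contents[key][row]
                self._next = (self._next + 1) % self.capacity
            indexes.append(index)
        return indexes


store = RollingStore(["state", "reward"], capacity=3)
first = store.put({"state": ["a", "b"], "reward": [1, 2]})
second = store.put({"state": ["c", "d", "e"], "reward": [3, 4, 5]})
```

The second put fills the last free slot and then wraps around, so the two oldest rows are replaced while every column keeps the same length.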
-
apply_multi_filters(filters: List[Callable])[source]¶ Multi-filter method.
The input to one filter is the output from its predecessor in the sequence.
- Parameters
filters (List[Callable]) – Filter list, each item is a lambda function, e.g., [lambda d: d['a'] == 1 and d['b'] == 1].
- Returns
Filtered indexes and corresponding objects.
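Chained filtering over a dictionary of lists can be sketched as follows. This standalone function is hypothetical and assumes each filter is a predicate applied to one row at a time (a dict mapping each key to that row's value), which matches the lambda in the example above but is not confirmed as the library's calling convention.

```python
def apply_multi_filters(data, filters):
    """data: dict of equal-length lists. Returns (indexes, rows) that pass every filter."""
    num_rows = len(next(iter(data.values())))
    indexes = list(range(num_rows))
    for keep in filters:
        # Each filter only sees the rows that passed the previous filters.
        indexes = [i for i in indexes
                   if keep({key: column[i] for key, column in data.items()})]
    rows = [{key: column[i] for key, column in data.items()} for i in indexes]
    return indexes, rows


data = {"a": [1, 1, 2, 1], "b": [1, 0, 1, 1]}
indexes, rows = apply_multi_filters(data, [lambda d: d["a"] == 1, lambda d: d["b"] == 1])
```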
-
apply_multi_samplers(samplers: list, replace: bool = True) → Tuple[source]¶ Multi-samplers method.
This implements chained sampling where the input to one sampler is the output from its predecessor in the sequence.
- Parameters
samplers (list) – A sequence of weight functions for computing the sampling weights of the items in the store, e.g., [lambda d: d['a'], lambda d: d['b']].
replace (bool) – If True, sampling will be performed with replacement.
- Returns
Sampled indexes and corresponding objects.
-
property
capacity¶ Store capacity.
If negative, the store grows without bound. Otherwise, the number of items in the store will not exceed this capacity.
-
get(indexes: [int]) → dict[source]¶ Get contents.
- Parameters
indexes – A sequence of indexes to retrieve contents at.
- Returns
Retrieved contents.
-
property
keys¶
-
property
overwrite_type¶ An OverwriteType member indicating the overwrite behavior when the store capacity is exceeded.
-
put(contents: Dict[str, List], overwrite_indexes: list = None) → List[int][source]¶ Put new contents in the store.
- Parameters
contents (dict) – Dictionary of items to add to the store. If the store is not empty, this must have the same keys as the store itself. Otherwise a StoreMisalignment will be raised.
overwrite_indexes (list, optional) – Indexes where the contents are to be overwritten. This is only used when the store has a fixed capacity and putting contents in the store would exceed this capacity. If this is None and overwriting is necessary, rolling or random overwriting will be done according to the overwrite_type property. Defaults to None.
- Returns
The indexes where the newly added entries reside in the store.
-
sample(size, weights: Union[list, numpy.ndarray] = None, replace: bool = True)[source]¶ Obtain a random sample from the experience pool.
- Parameters
size (int) – Sample size.
weights (Union[list, np.ndarray]) – Sampling weights.
replace (bool) – If True, sampling is performed with replacement. Defaults to True.
- Returns
Sampled indexes and the corresponding objects, e.g., [1, 2, 3], ['a', 'b', 'c'].
-
sample_by_key(key, size: int, replace: bool = True)[source]¶ Obtain a random sample from the store using one of the columns as sampling weights.
- Parameters
key – The column whose values are to be used as sampling weights.
size (int) – Sample size.
replace (bool) – If True, sampling is performed with replacement.
- Returns
Sampled indexes and the corresponding objects.
-
sample_by_keys(keys: list, sizes: list, replace: bool = True)[source]¶ Obtain a random sample from the store by chained sampling using multiple columns as sampling weights.
- Parameters
keys (list) – The columns whose values are to be used as sampling weights, one for each round of sampling in the chain.
sizes (list) – Sample sizes, one for each round of sampling in the chain.
replace (bool) – If True, sampling is performed with replacement.
- Returns
Sampled indexes and the corresponding objects.
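Chained sampling by several columns can be sketched with random.choices from the standard library. The standalone function below is hypothetical: it covers only sampling with replacement, assumes each round draws from the indexes selected in the previous round, and returns the sampled rows as dicts, which may differ from the library's return format.

```python
import random


def sample_by_keys(data, keys, sizes, rng=random):
    """data: dict of equal-length lists. Returns (indexes, rows) after chained sampling."""
    indexes = list(range(len(next(iter(data.values())))))
    for key, size in zip(keys, sizes):
        # Weights come from the chosen column, restricted to the current candidates.
        weights = [data[key][i] for i in indexes]
        indexes = rng.choices(indexes, weights=weights, k=size)
    rows = [{k: column[i] for k, column in data.items()} for i in indexes]
    return indexes, rows


data = {"a": [0, 1, 0, 1], "b": [1, 1, 1, 2]}
indexes, rows = sample_by_keys(data, keys=["a", "b"], sizes=[6, 3], rng=random.Random(0))
```

Rows with zero weight in the first column can never survive the first round, so only rows 1 and 3 can appear in the final sample here.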
-
update(indexes: list, contents: Dict[str, List])[source]¶ Update contents at given positions.
- Parameters
indexes (list) – Positions where updates are to be made.
contents (dict) – Contents to write to the internal store at given positions. It is subject to uniformity checks to ensure that all values have the same length.
- Returns
The indexes where store contents are updated.