Source code for maro.rl.scheduling.simple_parameter_scheduler

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Union

import numpy as np

from .scheduler import Scheduler


[docs]class LinearParameterScheduler(Scheduler): """Static exploration parameter generator based on a linear schedule. Args: max_iter (int): Maximum number of iterations. parameter_names (List[str]): List of exploration parameter names. start (Union[float, list, tuple, np.ndarray]): Exploration parameter values for the first episode. These values must correspond to ``parameter_names``. end (Union[float, list, tuple, np.ndarray]): Exploration parameter values rate for the last episode. These values must correspond to ``parameter_names``. """ def __init__( self, max_iter: int, parameter_names: List[str], start: Union[float, list, tuple, np.ndarray], end: Union[float, list, tuple, np.ndarray] ): super().__init__(max_iter) self._parameter_names = parameter_names if isinstance(start, float): self._current_values = start * np.ones(len(self._parameter_names)) elif isinstance(start, (list, tuple)): self._current_values = np.asarray(start) else: self._current_values = start if isinstance(end, float): end = end * np.ones(len(self._parameter_names)) elif isinstance(end, (list, tuple)): end = np.asarray(end) self._delta = (end - self._current_values) / (self._max_iter - 1)
[docs] def next_params(self): current_values = self._current_values.copy() self._current_values += self._delta return dict(zip(self._parameter_names, current_values))
[docs]class TwoPhaseLinearParameterScheduler(Scheduler): """Exploration parameter generator based on two linear schedules joined together. Args: max_iter (int): Maximum number of iterations. parameter_names (List[str]): List of exploration parameter names. split (float): The point where the switch from the first linear schedule to the second occurs. start (Union[float, list, tuple, np.ndarray]): Exploration parameter values for the first episode. These values must correspond to ``parameter_names``. mid (Union[float, list, tuple, np.ndarray]): Exploration parameter values where the switch from the first linear schedule to the second occurs. In other words, this is the exploration rate where the first linear schedule ends and the second begins. These values must correspond to ``parameter_names``. end (Union[float, list, tuple, np.ndarray]): Exploration parameter values for the last episode. These values must correspond to ``parameter_names``. Returns: An iterator over the series of exploration rates from episode 0 to ``max_iter`` - 1. """ def __init__( self, max_iter: int, parameter_names: List[str], split: float, start: Union[float, list, tuple, np.ndarray], mid: Union[float, list, tuple, np.ndarray], end: Union[float, list, tuple, np.ndarray] ): if split < 0 or split > 1.0: raise ValueError("split must be a float between 0 and 1.") super().__init__(max_iter) self._parameter_names = parameter_names self._split = int(self._max_iter * split) if isinstance(start, float): self._current_values = start * np.ones(len(self._parameter_names)) elif isinstance(start, (list, tuple)): self._current_values = np.asarray(start) else: self._current_values = start if isinstance(mid, float): mid = mid * np.ones(len(self._parameter_names)) elif isinstance(mid, (list, tuple)): mid = np.asarray(mid) if isinstance(end, float): end = end * np.ones(len(self._parameter_names)) elif isinstance(end, (list, tuple)): end = np.asarray(end) self._delta_1 = (mid - self._current_values) / self._split self._delta_2 = (end - mid) / (max_iter - self._split - 1)
[docs] def next_params(self): current_values = self._current_values.copy() self._current_values += self._delta_1 if self._iter_index < self._split else self._delta_2 return dict(zip(self._parameter_names, current_values))