Source code for maro.rl.agent.ddpg

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, Union

import numpy as np
import torch

from maro.rl.exploration import NoiseExplorer
from maro.rl.model import SimpleMultiHeadModel
from maro.utils.exception.rl_toolkit_exception import UnrecognizedTask

from .abs_agent import AbsAgent


class DDPGConfig:
    """Hyper-parameter container for the DDPG algorithm.

    Args:
        reward_discount (float): Reward decay as defined in standard RL terminology.
        q_value_loss_func (Callable): Loss function for the Q-value estimator.
        target_update_freq (int): Number of training rounds between target model updates.
        policy_loss_coefficient (float): Weight of the policy loss in the total loss, i.e.,
            loss = q_value_loss + ``policy_loss_coefficient`` * policy_loss. Defaults to 1.0.
        tau (float): Soft update coefficient, i.e., target_model = tau * eval_model + (1-tau) * target_model.
            Defaults to 1.0.
    """
    __slots__ = ["reward_discount", "q_value_loss_func", "target_update_freq", "policy_loss_coefficient", "tau"]

    def __init__(
        self,
        reward_discount: float,
        q_value_loss_func: Callable,
        target_update_freq: int,
        policy_loss_coefficient: float = 1.0,
        tau: float = 1.0,
    ):
        # Settings the caller must always provide.
        self.reward_discount, self.q_value_loss_func = reward_discount, q_value_loss_func
        self.target_update_freq = target_update_freq
        # Settings that fall back to their defaults when omitted.
        self.policy_loss_coefficient, self.tau = policy_loss_coefficient, tau
class DDPG(AbsAgent):
    """The Deep Deterministic Policy Gradient (DDPG) algorithm.

    References:
        https://arxiv.org/pdf/1509.02971.pdf
        https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg

    Args:
        model (SimpleMultiHeadModel): DDPG policy and q-value models. Its task names must be exactly
            "policy" and "q_value".
        config (DDPGConfig): Configuration for DDPG algorithm.
        explorer (NoiseExplorer): A NoiseExplorer instance for generating exploratory actions. Defaults to None.

    Raises:
        UnrecognizedTask: If the model's task names are missing or are not exactly "policy" and "q_value".
    """
    def __init__(self, model: SimpleMultiHeadModel, config: DDPGConfig, explorer: NoiseExplorer = None):
        if model.task_names is None or set(model.task_names) != {"policy", "q_value"}:
            raise UnrecognizedTask(f"Expected model task names 'policy' and 'q_value', but got {model.task_names}")
        super().__init__(model, config)
        self._explorer = explorer
        # A target copy is only kept when the model is trainable; otherwise ``learn`` cannot be used.
        self._target_model = model.copy() if model.trainable else None
        # Number of completed ``learn`` calls; drives the target model update schedule.
        self._train_cnt = 0

    def choose_action(self, state) -> Union[float, np.ndarray]:
        """Compute the policy's action(s) for a single state or a batch of states.

        Args:
            state (np.ndarray): Either a single 1-D state or a 2-D batch of states.

        Returns:
            The action for a single state, or an array with one entry per state for a batch. When the
            policy output has exactly one action dimension, that dimension is squeezed out.
        """
        state = torch.from_numpy(state).to(self.device)
        is_single = len(state.shape) == 1
        if is_single:
            # Add a batch dimension so that the model always sees a 2-D input.
            state = state.unsqueeze(dim=0)

        # ``Tensor.numpy()`` only works on CPU tensors, so the output is moved back to the host first.
        action = self.model(state, task_name="policy", training=False).detach().cpu().numpy()
        action_dim = action.shape[1]
        if self._explorer is not None:
            action = self._explorer(action)

        if action_dim == 1:
            action = action.squeeze(axis=1)

        return action[0] if is_single else action

    def learn(self, states: np.ndarray, actions: np.ndarray, rewards: np.ndarray, next_states: np.ndarray):
        """Perform one training step on a batch of transitions.

        The q-value loss and the weighted policy loss are summed and passed to ``self.model.learn`` as a
        single loss. Every ``config.target_update_freq`` calls, the target model is soft-updated towards
        the current model using ``config.tau``.

        Args:
            states (np.ndarray): Batch of states, 2-D with the batch along the first axis.
            actions (np.ndarray): Batch of actions taken; a 1-D array is treated as one scalar action per state.
            rewards (np.ndarray): Batch of rewards, one per transition.
            next_states (np.ndarray): Batch of successor states, aligned with ``states``.
        """
        states = torch.from_numpy(states).to(self.device)
        actual_actions = torch.from_numpy(actions).to(self.device)
        rewards = torch.from_numpy(rewards).to(self.device)
        next_states = torch.from_numpy(next_states).to(self.device)
        if len(actual_actions.shape) == 1:
            actual_actions = actual_actions.unsqueeze(dim=1)  # (N, 1)

        current_q_values = self.model(torch.cat([states, actual_actions], dim=1), task_name="q_value")
        current_q_values = current_q_values.squeeze(dim=1)  # (N,)

        # The bootstrapped target is Q'(s', mu'(s')): the target policy must act on the NEXT states.
        next_actions = self._target_model(next_states, task_name="policy", training=False)
        next_q_values = self._target_model(
            torch.cat([next_states, next_actions], dim=1), task_name="q_value", training=False
        ).squeeze(1)  # (N,)
        # Detached so that no gradient flows into the target computation.
        target_q_values = (rewards + self.config.reward_discount * next_q_values).detach()  # (N,)
        q_value_loss = self.config.q_value_loss_func(current_q_values, target_q_values)

        # The policy is improved by maximizing the q-value of its own actions, hence the negated mean.
        # NOTE(review): this loss also back-propagates through the q-value head; whether the q-value
        # parameters are affected depends on how ``self.model.learn`` applies the combined loss - confirm.
        actions_from_model = self.model(states, task_name="policy")
        policy_loss = -self.model(torch.cat([states, actions_from_model], dim=1), task_name="q_value").mean()

        self.model.learn(q_value_loss + self.config.policy_loss_coefficient * policy_loss)
        self._train_cnt += 1
        if self._train_cnt % self.config.target_update_freq == 0:
            self._target_model.soft_update(self.model, self.config.tau)