# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Union
import numpy as np
import torch
from maro.rl.model import SimpleMultiHeadModel
from maro.rl.utils import get_max, get_td_errors, select_by_actions
from maro.utils.exception.rl_toolkit_exception import UnrecognizedTask
from .abs_agent import AbsAgent
[docs]class DQNConfig:
"""Configuration for the DQN algorithm.
Args:
reward_discount (float): Reward decay as defined in standard RL terminology.
epsilon (float): Exploration rate for epsilon-greedy exploration. Defaults to None.
tau (float): Soft update coefficient, i.e., target_model = tau * eval_model + (1 - tau) * target_model.
double (bool): If True, the next Q values will be computed according to the double DQN algorithm,
i.e., q_next = Q_target(s, argmax(Q_eval(s, a))). Otherwise, q_next = max(Q_target(s, a)).
See https://arxiv.org/pdf/1509.06461.pdf for details. Defaults to False.
advantage_type (str): Advantage mode for the dueling architecture. Defaults to None, in which
case it is assumed that the regular Q-value model is used.
loss_cls: Loss function class for evaluating TD errors. Defaults to torch.nn.MSELoss.
target_update_freq (int): Number of training rounds between target model updates.
"""
__slots__ = [
"reward_discount", "target_update_freq", "epsilon", "tau", "double", "advantage_type", "loss_func"
]
def __init__(
self,
reward_discount: float,
target_update_freq: int,
epsilon: float = .0,
tau: float = 0.1,
double: bool = True,
advantage_type: str = None,
loss_cls=torch.nn.MSELoss
):
self.reward_discount = reward_discount
self.target_update_freq = target_update_freq
self.epsilon = epsilon
self.tau = tau
self.double = double
self.advantage_type = advantage_type
self.loss_func = loss_cls(reduction="none")
[docs]class DQN(AbsAgent):
"""The Deep-Q-Networks algorithm.
See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details.
Args:
model (SimpleMultiHeadModel): Q-value model.
config: Configuration for DQN algorithm.
"""
def __init__(self, model: SimpleMultiHeadModel, config: DQNConfig):
if (config.advantage_type is not None and
(model.task_names is None or set(model.task_names) != {"state_value", "advantage"})):
raise UnrecognizedTask(
f"Expected model task names 'state_value' and 'advantage' since dueling DQN is used, "
f"got {model.task_names}"
)
super().__init__(model, config)
self._training_counter = 0
self._target_model = model.copy() if model.trainable else None
[docs] def choose_action(self, state: np.ndarray) -> Union[int, np.ndarray]:
state = torch.from_numpy(state)
if self.device:
state = state.to(self.device)
is_single = len(state.shape) == 1
if is_single:
state = state.unsqueeze(dim=0)
q_values = self._get_q_values(state, training=False)
num_actions = q_values.shape[1]
greedy_action = q_values.argmax(dim=1).data
# No exploration
if self.config.epsilon == .0:
return greedy_action.item() if is_single else greedy_action.numpy()
if is_single:
return greedy_action if np.random.random() > self.config.epsilon else np.random.choice(num_actions)
# batch inference
return np.array([
act if np.random.random() > self.config.epsilon else np.random.choice(num_actions)
for act in greedy_action
])
[docs] def learn(self, states: np.ndarray, actions: np.ndarray, rewards: np.ndarray, next_states: np.ndarray):
states = torch.from_numpy(states)
actions = torch.from_numpy(actions)
rewards = torch.from_numpy(rewards)
next_states = torch.from_numpy(next_states)
if self.device:
states = states.to(self.device)
actions = actions.to(self.device)
rewards = rewards.to(self.device)
next_states = next_states.to(self.device)
q_all = self._get_q_values(states)
q = select_by_actions(q_all, actions)
next_q_all = self._get_q_values(next_states, is_eval=False, training=False)
if self.config.double:
next_q = select_by_actions(next_q_all) # (N,)
else:
next_q, _ = get_max(next_q_all) # (N,)
loss = get_td_errors(q, next_q, rewards, self.config.reward_discount, loss_func=self.config.loss_func)
self.model.step(loss.mean())
self._training_counter += 1
if self._training_counter % self.config.target_update_freq == 0:
self._target_model.soft_update(self.model, self.config.tau)
return loss.detach().numpy()
[docs] def set_exploration_params(self, epsilon):
self.config.epsilon = epsilon
def _get_q_values(self, states: torch.Tensor, is_eval: bool = True, training: bool = True):
output = self.model(states, training=training) if is_eval else self._target_model(states, training=False)
if self.config.advantage_type is None:
return output
else:
state_values = output["state_value"]
advantages = output["advantage"]
# Use mean or max correction to address the identifiability issue
corrections = advantages.mean(1) if self.config.advantage_type == "mean" else advantages.max(1)[0]
return state_values + advantages - corrections.unsqueeze(1)