# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod
import torch
from maro.rl.model import AbsCoreModel
class AbsAgent(ABC):
    """Abstract RL agent class.

    It's a sandbox for the RL algorithm, with scenario-specific details excluded so that the focus stays on
    algorithm development. Environment observations and decision events are expected to be converted to a
    uniform format before being passed in, and the output is expected to be converted to an
    environment-executable format before being returned to the environment. The agent's key responsibility
    is optimizing a policy based on interaction with the environment.

    Args:
        model (AbsCoreModel): Task model or container of task models required by the algorithm.
        config: Settings for the algorithm.
    """

    def __init__(self, model: "AbsCoreModel", config):
        self.model = model
        self.config = config
        # No device is recorded until ``to_device`` is called.
        self.device = None

    def to_device(self, device):
        """Record ``device`` and move the underlying model to it.

        Args:
            device: Target device, forwarded unchanged to ``self.model.to``.
        """
        self.device = device
        # Re-bind to the return value of ``to`` rather than assuming an in-place move.
        self.model = self.model.to(device)

    @abstractmethod
    def choose_action(self, state):
        """This method uses the underlying model(s) to compute an action from a shaped state.

        Args:
            state: A state object shaped by a ``StateShaper`` to conform to the model input format.

        Returns:
            The action to be taken given ``state``. It is usually necessary to use an ``ActionShaper`` to convert
            this to an environment executable action.

        Raises:
            NotImplementedError: If a subclass delegates to this base implementation.
        """
        # Raise (not return) so that an accidental ``super()`` call fails loudly
        # instead of yielding the exception class as if it were an action.
        raise NotImplementedError

    def set_exploration_params(self, **params):
        """Hook for updating exploration parameters.

        The base implementation ignores ``params`` and does nothing; subclasses may override it.
        """
        pass

    @abstractmethod
    def learn(self, *args, **kwargs):
        """Algorithm-specific training logic.

        The parameters are data to train the underlying model on. Algorithm-specific loss and optimization
        should be reflected here.

        Raises:
            NotImplementedError: If a subclass delegates to this base implementation.
        """
        raise NotImplementedError

    def load_model(self, model):
        """Load model parameters from memory.

        Args:
            model: State to restore, passed directly to ``self.model.load_state_dict``.
        """
        self.model.load_state_dict(model)

    def dump_model(self):
        """Return the state of the algorithm's trainable model(s), as given by ``self.model.state_dict()``."""
        return self.model.state_dict()

    def load_model_from_file(self, path: str):
        """Load trainable models from disk.

        Args:
            path (str): Path of the file to read; it is passed as-is to ``torch.load``.
        """
        # NOTE: ``torch.load`` relies on unpickling, so only load files from trusted sources.
        self.model.load_state_dict(torch.load(path))

    def dump_model_to_file(self, path: str):
        """Dump the algorithm's trainable models to disk.

        Args:
            path (str): Path of the file to write; it is passed as-is to ``torch.save``.
        """
        torch.save(self.model.state_dict(), path)