Source code for maro.rl.storage.abs_store

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import ABC, abstractmethod
from typing import Callable, Sequence


class AbsStore(ABC):
    """A data store abstraction that supports get, put, update, filter and sample operations.

    Subclasses must implement ``get``, ``update`` and ``sample`` (they are abstract,
    so a subclass missing any of them cannot be instantiated). ``put`` and ``filter``
    are not abstract: their base implementations do nothing and return ``None``, so
    stores that support those operations should override them.
    """

    def __init__(self):
        pass

    @abstractmethod
    def get(self, indexes: Sequence):
        """Get contents.

        Args:
            indexes (Sequence): A sequence of indexes to retrieve contents at.

        Returns:
            Retrieved contents.
        """
        pass

    def put(self, contents: Sequence):
        """Put new contents.

        The base implementation is a no-op that returns ``None``; override it in
        stores that accept new contents.

        Args:
            contents (Sequence): Contents to be added to the store.

        Returns:
            The indexes where the newly added entries reside in the store.
        """
        pass

    @abstractmethod
    def update(self, indexes: Sequence, contents: Sequence):
        """Update the store contents at given positions.

        Args:
            indexes (Sequence): Positions where updates are to be made.
            contents (Sequence): Item list, which has the same length as indexes.

        Returns:
            The indexes where store contents are updated.
        """
        pass

    def filter(self, filters: Sequence[Callable]):
        """Multi-filter method.

        The input to one filter is the output from the previous filter. The base
        implementation is a no-op that returns ``None``; override it in stores that
        support filtering.

        Args:
            filters (Sequence[Callable]): Filter list, each item is a lambda function,
                e.g., [lambda d: d['a'] == 1 and d['b'] == 1].

        Returns:
            Filtered indexes and corresponding objects.
        """
        pass

    @abstractmethod
    def sample(self, size: int, weights: Sequence, replace: bool = True):
        """Obtain a random sample from the store.

        Args:
            size (int): Sample size, i.e., the number of items to draw.
            weights (Sequence): A sequence of sampling weights (how they map to stored
                items is defined by the concrete store).
            replace (bool): If True, sampling is performed with replacement.
                Defaults to True.

        Returns:
            A random sample from the store.
        """
        pass