import abc
import functools
import gzip
import os
import pickle
from contextlib import contextmanager
from itertools import product
from atomicwrites import AtomicWriter
def named_product(**items):
    """Return the Cartesian product of keyword-argument iterables as dicts.

    Each keyword maps to an iterable of candidate values. The result holds
    one ``dict`` per combination, keyed by the keyword names, in the order
    produced by ``itertools.product``.

    >>> named_product(a=[1, 2], b=[3])
    [{'a': 1, 'b': 3}, {'a': 2, 'b': 3}]
    """
    keys = list(items)
    combos = product(*items.values())
    return [dict(zip(keys, combo)) for combo in combos]
@contextmanager
def restore(*learners):
    """Snapshot the given learners on entry and restore them on exit.

    Each learner's ``__getstate__()`` is captured before the ``with`` body
    runs. Afterwards, whether the body finished normally or raised, every
    learner gets its snapshot back through ``__setstate__``.
    """
    saved = [(learner, learner.__getstate__()) for learner in learners]
    try:
        yield
    finally:
        for learner, snapshot in saved:
            learner.__setstate__(snapshot)
def cache_latest(f):
    """Decorate a method so its most recent result is kept on the instance.

    Every call still executes ``f``. The value it returns is stored under
    ``self._cache[f.__name__]`` (``self`` being the first positional
    argument), and that stored value is what the wrapper returns. The
    ``_cache`` dict is created on the instance the first time it is needed.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        instance = args[0]
        if not hasattr(instance, "_cache"):
            # Lazily create the per-instance cache on first use.
            instance._cache = {}
        name = f.__name__
        instance._cache[name] = f(*args, **kwargs)
        return instance._cache[name]

    return wrapper
def save(fname, data, compress=True):
    """Pickle ``data`` and write it to ``fname``.

    ``~`` in ``fname`` is expanded, and any missing parent directories are
    created. The object is pickled with ``pickle.HIGHEST_PROTOCOL`` and
    gzip-compressed unless ``compress`` is false. The bytes are written
    through ``atomicwrites.AtomicWriter`` in ``"wb"`` mode with
    ``overwrite=True``, so an existing file at that path is replaced.
    """
    path = os.path.expanduser(fname)
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
    payload = gzip.compress(payload) if compress else payload

    writer = AtomicWriter(path, "wb", overwrite=True)
    with writer.open() as fh:
        fh.write(payload)
def load(fname, compress=True):
    """Read and unpickle the object stored in ``fname``.

    This is the counterpart of ``save``. ``~`` in ``fname`` is expanded, and
    the file is opened with ``gzip.open`` when ``compress`` is true, or with
    plain ``open`` otherwise.

    Warning: unpickling can execute arbitrary code, so only load files from
    a trusted source.
    """
    path = os.path.expanduser(fname)
    if compress:
        opener = gzip.open
    else:
        opener = open
    with opener(path, "rb") as fh:
        return pickle.load(fh)
def copy_docstring_from(other):
    """Build a decorator that copies ``other``'s metadata onto a method.

    The decorated method is updated in place by ``functools.update_wrapper``
    with its default settings. It receives ``other``'s attributes listed in
    ``functools.WRAPPER_ASSIGNMENTS`` (e.g. ``__doc__``, ``__name__``,
    ``__qualname__``), its ``__dict__`` is merged, and ``__wrapped__`` is set
    to ``other``. The same method object is returned.
    """

    def decorator(method):
        return functools.update_wrapper(method, other)

    return decorator
class _RequireAttrsABCMeta(abc.ABCMeta):
def __call__(self, *args, **kwargs):
obj = super().__call__(*args, **kwargs)
for name, type_ in obj.__annotations__.items():
try:
x = getattr(obj, name)
except AttributeError:
raise AttributeError(
f"Required attribute {name} not set in __init__."
) from None
else:
if not isinstance(x, type_):
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
raise TypeError(msg)
return obj