from __future__ import annotations
import abc
import functools
import gzip
import inspect
import os
import pickle
import warnings
from contextlib import _GeneratorContextManager, contextmanager
from itertools import product
from typing import Any, Callable, Mapping, Sequence
import cloudpickle
def named_product(**items: Sequence[Any]) -> list[dict[str, Any]]:
    """Return the Cartesian product of the keyword arguments as dicts.

    Parameters
    ----------
    **items
        Each keyword maps a name to the sequence of values that name may
        take.

    Returns
    -------
    list of dict
        One dict per combination, keyed by the keyword names, in the order
        produced by `itertools.product`. Calling without arguments yields a
        single empty dict.
    """
    names = list(items)
    return [dict(zip(names, combination)) for combination in product(*items.values())]
@contextmanager
def restore(*learners) -> _GeneratorContextManager:
    """Context manager that rolls the given learners back on exit.

    The state of every learner is captured with ``__getstate__`` before the
    body runs and re-applied with ``__setstate__`` afterwards, whether the
    body finished normally or raised.
    """
    snapshots = []
    for learner in learners:
        snapshots.append((learner, learner.__getstate__()))
    try:
        yield
    finally:
        # Re-apply the captured states in the order the learners were given.
        for learner, snapshot in snapshots:
            learner.__setstate__(snapshot)
def cache_latest(f: Callable) -> Callable:
    """Decorate a method so its most recent result is kept on the instance.

    After every call the return value is stored under
    ``self._cache[f.__name__]``; the ``_cache`` dict is created on the
    instance the first time it is needed.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # The decorated callable is a method, so the instance comes first.
        instance = args[0]
        try:
            instance._cache
        except AttributeError:
            instance._cache = {}
        result = f(*args, **kwargs)
        instance._cache[f.__name__] = result
        return instance._cache[f.__name__]

    return wrapper
def save(fname: str, data: Any, compress: bool = True) -> bool:
    """Pickle *data* with cloudpickle and write it to *fname*.

    The payload is first written to a temporary file next to the target
    (``"<fname>.<pid>"``) and then moved into place with `os.replace`, so
    the target is only ever replaced by a completely written file.

    Parameters
    ----------
    fname
        Destination path; ``~`` is expanded and missing parent directories
        are created.
    data
        Object to serialize.
    compress
        When true, the pickled bytes are gzip-compressed before writing.

    Returns
    -------
    bool
        ``True`` on success, ``False`` if writing the temporary file or
        moving it into place raised `OSError`. Errors from directory
        creation or pickling are not caught here.
    """
    fname = os.path.expanduser(fname)
    dirname = os.path.dirname(fname)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    blob = cloudpickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
    if compress:
        blob = gzip.compress(blob)

    temp_file = f"{fname}.{os.getpid()}"
    try:
        with open(temp_file, "wb") as f:
            f.write(blob)
        os.replace(temp_file, fname)
    except OSError:
        return False
    finally:
        # Also runs when the write itself failed, so a partially written
        # temporary file is never left behind. After a successful replace
        # the temporary path no longer exists and nothing is removed.
        if os.path.exists(temp_file):
            os.remove(temp_file)
    return True
def load(fname: str, compress: bool = True) -> Any:
    """Read back an object from a file containing cloudpickle data.

    Parameters
    ----------
    fname
        Path of the file to read; ``~`` is expanded.
    compress
        When true the file is opened with `gzip.open`, otherwise with the
        built-in `open`.

    Returns
    -------
    Any
        The unpickled object.

    Notes
    -----
    Unpickling executes code embedded in the file, so only load files from
    sources you trust.
    """
    path = os.path.expanduser(fname)
    if compress:
        opener = gzip.open
    else:
        opener = open
    with opener(path, "rb") as stream:
        return cloudpickle.load(stream)
def copy_docstring_from(other: Callable) -> Callable:
    """Return a decorator that makes a method look like *other*.

    The decorated method receives the metadata that `functools.wraps`
    copies from *other* (including ``__doc__`` and ``__name__``) and is
    returned as the same function object.
    """

    def decorator(method):
        # Same effect as ``functools.wraps(other)(method)``.
        return functools.update_wrapper(method, other)

    return decorator
class _RequireAttrsABCMeta(abc.ABCMeta):
def __call__(self, *args, **kwargs):
obj = super().__call__(*args, **kwargs)
for name, type_ in obj.__annotations__.items():
try:
x = getattr(obj, name)
except AttributeError:
raise AttributeError(
f"Required attribute {name} not set in __init__."
) from None
else:
if not isinstance(x, type_):
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
raise TypeError(msg)
return obj
def _default_parameters(function, function_prefix: str = "function."):
sig = inspect.signature(function)
defaults = {
f"{function_prefix}{k}": v.default
for i, (k, v) in enumerate(sig.parameters.items())
if v.default != inspect._empty and i >= 1
}
return defaults
def assign_defaults(function, df, function_prefix: str = "function."):
    """Add the default parameters of *function* to *df* in place.

    For every default returned by ``_default_parameters`` a column named
    after the prefixed parameter is written, holding that default in every
    row and converted to the ``"category"`` dtype.

    Parameters
    ----------
    function
        Callable whose defaults are recorded.
    df
        Object supporting ``len``, column assignment and ``astype`` on its
        columns (presumably a ``pandas.DataFrame``); it is modified in place.
    function_prefix
        Prefix used for the new column names.
    """
    n_rows = len(df)
    for column, value in _default_parameters(function, function_prefix).items():
        df[column] = n_rows * [value]
        df[column] = df[column].astype("category")
def partial_function_from_dataframe(function, df, function_prefix: str = "function."):
    """Bind the parameters stored in *df* to *function*.

    Every column whose label is a string starting with *function_prefix* is
    treated as a keyword argument: the remainder of the label is the
    parameter name and the column's single unique value is the argument.

    Parameters
    ----------
    function
        Callable to which the parameters are bound.
    df
        Object exposing ``columns`` and per-column ``unique()`` (presumably
        a ``pandas.DataFrame``).
    function_prefix
        Non-empty prefix that marks function-parameter columns.

    Returns
    -------
    callable
        *function* itself when no prefixed column exists, otherwise a
        `functools.partial` with the collected keyword arguments.

    Raises
    ------
    ValueError
        If *function_prefix* is empty, a prefixed column has no value or
        more than one distinct value, or a prefixed column names a
        parameter that *function* does not have.

    Warns
    -----
    UserWarning
        If a bound value differs from the default already declared by
        *function*.
    """
    if function_prefix == "":
        raise ValueError(
            "The function_prefix cannot be an empty string because"
            " it is used to distinguish between function and learner parameters."
        )

    kwargs = {}
    for col in df.columns:
        # Column labels are not necessarily strings; only string labels can
        # carry the prefix.
        if not (isinstance(col, str) and col.startswith(function_prefix)):
            continue
        values = list(df[col].unique())
        if not values:
            raise ValueError(f"The column '{col}' has no values.")
        if len(values) > 1:
            raise ValueError(f"The column '{col}' can only have one value.")
        kwargs[col[len(function_prefix):]] = values[0]

    if not kwargs:
        return function

    parameters = inspect.signature(function).parameters
    for k, v in kwargs.items():
        if k not in parameters:
            raise ValueError(
                f"The DataFrame contains a default parameter"
                f" ({k}={v}) but the function does not have that parameter."
            )
        default = parameters[k].default
        # The "no default" sentinel is compared by identity; only real
        # defaults are compared by value.
        if default is not inspect.Parameter.empty and v != default:
            warnings.warn(
                f"The DataFrame contains a default parameter"
                f" ({k}={v}) but the function already has a default ({k}={default})."
                " The DataFrame's value will be used.",
                stacklevel=2,
            )
    return functools.partial(function, **kwargs)