Source code for adaptive.learner.average_learner1D

import math
import sys
from collections import defaultdict
from copy import deepcopy
from math import hypot
from typing import (
    Callable,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
)

import numpy as np
import scipy.stats
from sortedcollections import ItemSortedDict
from sortedcontainers import SortedDict

from adaptive.learner.learner1D import Learner1D, _get_intervals
from adaptive.notebook_integration import ensure_holoviews
from adaptive.types import Real

Point = Tuple[int, Real]
Points = List[Point]

__all__: List[str] = ["AverageLearner1D"]


class AverageLearner1D(Learner1D):
    """Learns and predicts a noisy function 'f:ℝ → ℝ'.

    Parameters
    ----------
    function : callable
        The function to learn. Must take a tuple of ``(seed, x)`` and
        return a real number.
    bounds : pair of reals
        The bounds of the interval on which to learn 'function'.
    loss_per_interval : callable, optional
        A function that returns the loss for a single interval of the domain.
        If not provided, a default is used that takes the scaled distance in
        the x-y plane as the loss. See the notes of `adaptive.Learner1D` for
        more details.
    delta : float, optional, default 0.2
        This parameter controls the resampling condition: a point is
        resampled if its uncertainty is larger than delta times the smallest
        neighboring interval.
        We strongly recommend ``0 < delta <= 1``.
    alpha : float (0 < alpha < 1), default 0.005
        The true value of the function at x is within the confidence interval
        ``[self.data[x] - self.error[x], self.data[x] + self.error[x]]`` with
        probability ``1 - 2 * alpha``.
        We recommend to keep ``alpha = 0.005``.
    neighbor_sampling : float (0 < neighbor_sampling <= 1), default 0.3
        Each new point is initially sampled at least
        ``neighbor_sampling * 100`` percent of the average number of samples
        of its neighbors.
    min_samples : int (min_samples > 0), default 50
        Minimum number of samples at each point x. Each new point is
        initially sampled at least min_samples times.
    max_samples : int (min_samples < max_samples), default sys.maxsize
        Maximum number of samples at each point x.
    min_error : float (min_error >= 0), default 0
        Minimum size of the confidence intervals. The true value of the
        function at x is within the confidence interval
        ``[self.data[x] - self.error[x], self.data[x] + self.error[x]]``
        with probability ``1 - 2 * alpha``.
        If ``self.error[x] < min_error``, then x will not be resampled
        anymore, i.e., the smallest confidence interval at x is
        ``[self.data[x] - min_error, self.data[x] + min_error]``.
    """

    def __init__(
        self,
        function: Callable[[Tuple[int, Real]], Real],
        bounds: Tuple[Real, Real],
        loss_per_interval: Optional[
            Callable[[Sequence[Real], Sequence[Real]], float]
        ] = None,
        delta: float = 0.2,
        alpha: float = 0.005,
        neighbor_sampling: float = 0.3,
        min_samples: int = 50,
        max_samples: int = sys.maxsize,
        min_error: float = 0,
    ):
        if not (0 < delta <= 1):
            raise ValueError("Learner requires 0 < delta <= 1.")
        if not (0 < alpha <= 1):
            raise ValueError("Learner requires 0 < alpha <= 1.")
        if not (0 < neighbor_sampling <= 1):
            raise ValueError("Learner requires 0 < neighbor_sampling <= 1.")
        if min_samples <= 0:
            raise ValueError("min_samples should be positive.")
        if min_samples > max_samples:
            raise ValueError("max_samples should be larger than min_samples.")

        super().__init__(function, bounds, loss_per_interval)

        self.delta = delta
        self.alpha = alpha
        self.min_samples = min_samples
        self.min_error = min_error
        self.max_samples = max_samples
        self.neighbor_sampling = neighbor_sampling

        # Contains all samples f(x) for each point x in the form
        # {x0: {0: f_0(x0), 1: f_1(x0), ...}, ...}
        self._data_samples = SortedDict()
        # Contains the number of samples taken at each point x
        # in the form {x0: n0, x1: n1, ...}
        self._number_samples = SortedDict()
        # This set contains the points x that have fewer than min_samples
        # samples or fewer than neighbor_sampling * 100 percent of the
        # samples of their neighbors
        self._undersampled_points: Set[Real] = set()
        # Contains the error in the estimate of the mean at each point x
        # in the form {x0: error(x0), ...}
        self.error: Dict[Real, float] = decreasing_dict()
        # Distance between two neighboring points in the form
        # {xi: ((xii - xi)^2 + (yii - yi)^2)^0.5, ...}
        self._distances: Dict[Real, float] = decreasing_dict()
        # {xii: error[xii] / min(_distances[xi], _distances[xii]), ...}
        self.rescaled_error: Dict[Real, float] = decreasing_dict()

    @property
    def nsamples(self) -> int:
        """Returns the total number of samples."""
        return sum(self._number_samples.values())

    @property
    def min_samples_per_point(self) -> int:
        if not self._number_samples:
            return 0
        return min(self._number_samples.values())

    def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
        """Return 'n' points that are expected to maximally reduce the loss."""
        # If some point is undersampled, resample it
        if len(self._undersampled_points):
            x = next(iter(self._undersampled_points))
            points, loss_improvements = self._ask_for_more_samples(x, n)
        # If fewer than 2 points were sampled, sample a new one
        elif len(self.data) <= 1:
            # TODO: if `n` is very large, we should suggest a few different points.
            points, loss_improvements = self._ask_for_new_point(n)
        # Else, check the resampling condition
        else:
            # rescaled_error may be empty (e.g. when sigma=0)
            if len(self.rescaled_error):
                # The point with the largest rescaled error comes first
                x, resc_error = self.rescaled_error.peekitem(0)
                # Resampling condition
                if resc_error > self.delta:
                    points, loss_improvements = self._ask_for_more_samples(x, n)
                else:
                    points, loss_improvements = self._ask_for_new_point(n)
            else:
                points, loss_improvements = self._ask_for_new_point(n)

        if tell_pending:
            for p in points:
                self.tell_pending(p)

        return points, loss_improvements
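
    # Worked example (editorial note): with the default ``delta = 0.2``, a
    # point whose error estimate is 0.5 while its smallest neighboring
    # x-y distance is 1.0 has a rescaled error of 0.5 / 1.0 = 0.5 > 0.2,
    # so ``ask`` keeps resampling it instead of suggesting a new x.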

    def _ask_for_more_samples(self, x: Real, n: int) -> Tuple[Points, List[float]]:
        """When asking for n points, the learner returns n times an existing
        point to be resampled, since in general n << min_samples and this
        point will need to be resampled many more times."""
        n_existing = self._number_samples.get(x, 0)
        points = [(seed + n_existing, x) for seed in range(n)]
        xl, xr = self.neighbors_combined[x]
        loss_left = self.losses_combined.get((xl, x), float("inf"))
        loss_right = self.losses_combined.get((x, xr), float("inf"))
        loss = (loss_left + loss_right) / 2
        if math.isinf(loss):
            loss_improvement = float("inf")
        else:
            loss_improvement = loss - loss * np.sqrt(n_existing) / np.sqrt(
                n_existing + n
            )
        loss_improvements = [loss_improvement / n] * n
        return points, loss_improvements

    def _ask_for_new_point(self, n: int) -> Tuple[Points, List[float]]:
        """When asking for n new points, the learner returns n times a single
        new point, since in general n << min_samples and this point will need
        to be resampled many more times."""
        points, (loss_improvement,) = self._ask_points_without_adding(1)
        points = [(seed, x) for seed, x in zip(range(n), n * points)]
        loss_improvements = [loss_improvement / n] * n
        return points, loss_improvements
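
    # Worked example (editorial note): if the average neighboring loss at x
    # is 1.0 and x already has n_existing = 100 samples, asking for n = 10
    # more gives loss_improvement = 1.0 - 1.0 * sqrt(100) / sqrt(110) ≈ 0.047,
    # spread as ≈ 0.0047 per requested sample; the loss shrinks like the
    # standard error of the mean, i.e. proportionally to 1 / sqrt(nsamples).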

    def tell_pending(self, seed_x: Point) -> None:
        _, x = seed_x
        self.pending_points.add(seed_x)
        if x not in self.data:
            self._update_neighbors(x, self.neighbors_combined)
            self._update_losses(x, real=False)

    def tell(self, seed_x: Point, y: Real) -> None:
        seed, x = seed_x
        if y is None:
            raise TypeError(
                "Y-value may not be None, use learner.tell_pending(x) "
                "to indicate that this value is currently being calculated"
            )

        if x not in self.data:
            self._update_data(x, y, "new")
            self._update_data_structures(seed_x, y, "new")
        elif seed not in self._data_samples[x]:  # check if the seed is new
            self._update_data(x, y, "resampled")
            self._update_data_structures(seed_x, y, "resampled")
        self.pending_points.discard(seed_x)
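
    # Note (editorial): samples are keyed by ``(seed, x)``, so telling the
    # same x under a fresh seed registers a new sample, while repeating a
    # seed is ignored. For example, ``learner.tell((0, 0.5), 1.2)`` followed
    # by ``learner.tell((1, 0.5), 1.3)`` stores two samples at x = 0.5.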

    def _update_rescaled_error_in_mean(self, x: Real, point_type: str) -> None:
        """Updates ``self.rescaled_error``.

        Parameters
        ----------
        point_type : str
            Must be either "new" or "resampled".
        """
        # Update neighbors
        x_left, x_right = self.neighbors[x]
        dists = self._distances
        if x_left is None and x_right is None:
            return

        if x_left is None:
            d_left = dists[x]
        else:
            d_left = dists[x_left]
            if x_left in self.rescaled_error:
                xll = self.neighbors[x_left][0]
                norm = dists[x_left] if xll is None else min(dists[xll], dists[x_left])
                self.rescaled_error[x_left] = self.error[x_left] / norm

        if x_right is None:
            d_right = dists[x_left]
        else:
            d_right = dists[x]
            if x_right in self.rescaled_error:
                xrr = self.neighbors[x_right][1]
                norm = dists[x] if xrr is None else min(dists[x], dists[x_right])
                self.rescaled_error[x_right] = self.error[x_right] / norm

        # Update x
        if point_type == "resampled":
            norm = min(d_left, d_right)
            self.rescaled_error[x] = self.error[x] / norm

    def _update_data(self, x: Real, y: Real, point_type: str) -> None:
        if point_type == "new":
            self.data[x] = y
        elif point_type == "resampled":
            # Incrementally update the running mean with the new sample y
            n = len(self._data_samples[x])
            new_average = self.data[x] * n / (n + 1) + y / (n + 1)
            self.data[x] = new_average

    def _update_data_structures(self, seed_x: Point, y: Real, point_type: str) -> None:
        seed, x = seed_x

        if point_type == "new":
            self._data_samples[x] = {seed: y}

            if not self.bounds[0] <= x <= self.bounds[1]:
                return

            self._update_neighbors(x, self.neighbors_combined)
            self._update_neighbors(x, self.neighbors)
            self._update_scale(x, y)
            self._update_losses(x, real=True)

            # If the scale has increased enough, recompute all losses.
            if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
                for interval in reversed(self.losses):
                    self._update_interpolated_loss_in_interval(*interval)
                self._oldscale = deepcopy(self._scale)

            self._number_samples[x] = 1
            self._undersampled_points.add(x)
            self.error[x] = np.inf
            self.rescaled_error[x] = np.inf
            self._update_distances(x)
            self._update_rescaled_error_in_mean(x, "new")

        elif point_type == "resampled":
            self._data_samples[x][seed] = y
            ns = self._number_samples
            ns[x] += 1
            n = ns[x]
            if (x in self._undersampled_points) and (n >= self.min_samples):
                x_left, x_right = self.neighbors[x]
                if x_left is not None and x_right is not None:
                    nneighbor = (ns[x_left] + ns[x_right]) / 2
                elif x_left is not None:
                    nneighbor = ns[x_left]
                elif x_right is not None:
                    nneighbor = ns[x_right]
                else:
                    nneighbor = 0
                if n > self.neighbor_sampling * nneighbor:
                    self._undersampled_points.discard(x)

            # We compute the error in the estimation of the mean as
            # the std of the mean multiplied by a t-Student factor to ensure
            # that the mean value lies within the correct interval of
            # confidence.
            y_avg = self.data[x]
            ys = self._data_samples[x].values()
            self.error[x] = self._calc_error_in_mean(ys, y_avg, n)
            self._update_distances(x)
            self._update_rescaled_error_in_mean(x, "resampled")

            if self.error[x] <= self.min_error or n >= self.max_samples:
                self.rescaled_error.pop(x, None)

            # We also need to update scale and losses.
            self._update_scale(x, y)
            self._update_losses_resampling(x, real=True)

            # If the scale has increased enough, recompute all losses.
            # We only update the scale considering resampled points, since
            # new points are more likely to be outliers.
            if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
                for interval in reversed(self.losses):
                    self._update_interpolated_loss_in_interval(*interval)
                self._oldscale = deepcopy(self._scale)

    def _update_distances(self, x: Real) -> None:
        x_left, x_right = self.neighbors[x]
        y = self.data[x]
        if x_left is not None:
            self._distances[x_left] = hypot(x - x_left, y - self.data[x_left])
        if x_right is not None:
            self._distances[x] = hypot(x_right - x, self.data[x_right] - y)

    def _update_losses_resampling(self, x: Real, real=True) -> None:
        """Update all losses that depend on x, whenever the new point is a
        re-sampled point."""
        # (x_left, x_right) are the "real" neighbors of 'x'.
        x_left, x_right = self._find_neighbors(x, self.neighbors)
        # (a, b) are the neighbors of the combined interpolated
        # and "real" intervals.
        a, b = self._find_neighbors(x, self.neighbors_combined)

        if real:
            for ival in _get_intervals(x, self.neighbors, self.nth_neighbors):
                self._update_interpolated_loss_in_interval(*ival)
        elif x_left is not None and x_right is not None:
            # 'x' happens to be in between two real points,
            # so we can interpolate the losses.
            dx = x_right - x_left
            loss = self.losses[x_left, x_right]
            self.losses_combined[a, x] = (x - a) * loss / dx
            self.losses_combined[x, b] = (b - x) * loss / dx

        # (no real point left of x) or (no real point right of a)
        left_loss_is_unknown = (x_left is None) or (not real and x_right is None)
        if (a is not None) and left_loss_is_unknown:
            self.losses_combined[a, x] = float("inf")

        # (no real point right of x) or (no real point left of b)
        right_loss_is_unknown = (x_right is None) or (not real and x_left is None)
        if (b is not None) and right_loss_is_unknown:
            self.losses_combined[x, b] = float("inf")

    def _calc_error_in_mean(self, ys: Iterable[Real], y_avg: Real, n: int) -> float:
        # Sample variance of the ys; dividing by n below gives the variance
        # of the estimated mean.
        variance_in_mean = sum((y - y_avg) ** 2 for y in ys) / (n - 1)
        t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
        return t_student * (variance_in_mean / n) ** 0.5
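
    # Worked example (editorial note): for n = 50 samples and the default
    # alpha = 0.005, scipy.stats.t.ppf(0.995, df=49) ≈ 2.68, so the stored
    # error is roughly 2.68 * sample_std / sqrt(50): the half-width of a
    # confidence interval that holds with probability 1 - 2 * alpha = 0.99.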

    def tell_many(self, xs: Points, ys: Sequence[Real]) -> None:
        # Check that all x are within the bounds
        # TODO: remove this requirement, all other learners add the data
        # but ignore it going forward.
        if not all(self.bounds[0] <= x <= self.bounds[1] for _, x in xs):
            raise ValueError(
                "x value out of bounds, "
                "remove x or enlarge the bounds of the learner"
            )

        # Create a mapping of points to a dict of samples
        mapping: DefaultDict[Real, Dict[int, Real]] = defaultdict(dict)
        for (seed, x), y in zip(xs, ys):
            mapping[x][seed] = y

        for x, seed_y_mapping in mapping.items():
            if len(seed_y_mapping) == 1:
                seed, y = list(seed_y_mapping.items())[0]
                self.tell((seed, x), y)
            elif len(seed_y_mapping) > 1:
                # If we stored more than one y-value for this x,
                # use a more efficient routine to tell many samples
                # simultaneously, before we move on to a new x.
                self.tell_many_at_point(x, seed_y_mapping)
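
    # Usage sketch (editorial addition): ``tell_many`` groups samples by x,
    # e.g. ``learner.tell_many([(0, 0.1), (1, 0.1), (0, 0.2)], [1.0, 1.1, 0.9])``
    # routes the two samples at x = 0.1 through ``tell_many_at_point`` and
    # the single sample at x = 0.2 through ``tell``.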

    def tell_many_at_point(self, x: Real, seed_y_mapping: Dict[int, Real]) -> None:
        """Tell the learner about many samples at a certain location x.

        Parameters
        ----------
        x : float
            Value from the function domain.
        seed_y_mapping : Dict[int, Real]
            Dictionary of ``seed`` -> ``y`` at ``x``.
        """
        # Check that x is within the bounds
        if not self.bounds[0] <= x <= self.bounds[1]:
            raise ValueError(
                "x value out of bounds, "
                "remove x or enlarge the bounds of the learner"
            )

        # If x is a new point:
        if x not in self.data:
            # we make a copy because we don't want to modify the original dict
            seed_y_mapping = seed_y_mapping.copy()
            seed = next(iter(seed_y_mapping))
            y = seed_y_mapping.pop(seed)
            self._update_data(x, y, "new")
            self._update_data_structures((seed, x), y, "new")

        ys = np.array(list(seed_y_mapping.values()))

        # If x is not a new point or if there was more than one sample in ys:
        if len(ys) > 0:
            self._data_samples[x].update(seed_y_mapping)
            n = len(ys) + self._number_samples[x]
            self.data[x] = (
                np.mean(ys) * len(ys) + self.data[x] * self._number_samples[x]
            ) / n
            self._number_samples[x] = n
            # `self._update_data(x, y, "new")` included the point in
            # _undersampled_points. We remove it if there are more than
            # min_samples samples, disregarding neighbor_sampling.
            if n > self.min_samples:
                self._undersampled_points.discard(x)
            self.error[x] = self._calc_error_in_mean(
                self._data_samples[x].values(), self.data[x], n
            )
            self._update_distances(x)
            self._update_rescaled_error_in_mean(x, "resampled")
            if self.error[x] <= self.min_error or n >= self.max_samples:
                self.rescaled_error.pop(x, None)
            self._update_scale(x, min(self._data_samples[x].values()))
            self._update_scale(x, max(self._data_samples[x].values()))
            self._update_losses_resampling(x, real=True)
            if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
                for interval in reversed(self.losses):
                    self._update_interpolated_loss_in_interval(*interval)
                self._oldscale = deepcopy(self._scale)
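
    # Usage sketch (editorial addition): precomputed samples can be fed in
    # bulk with, e.g., ``learner.tell_many_at_point(0.5, {0: 1.2, 1: 1.3})``,
    # which is equivalent to two ``tell`` calls at x = 0.5 but cheaper.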

    def _get_data(self) -> Dict[Real, Real]:
        return self._data_samples

    def _set_data(self, data: Dict[Real, Real]) -> None:
        if data:
            for x, samples in data.items():
                self.tell_many_at_point(x, samples)

    def plot(self):
        """Returns a plot of the evaluated data with error bars.

        This is only implemented for scalar functions, i.e., it requires
        ``vdim=1``.

        Returns
        -------
        plot : `holoviews.element.Scatter * holoviews.element.ErrorBars * holoviews.element.Path`
            Plot of the evaluated data.
        """
        hv = ensure_holoviews()
        if not self.data:
            p = hv.Scatter([]) * hv.ErrorBars([]) * hv.Path([])
        elif not self.vdim > 1:
            xs, ys = zip(*sorted(self.data.items()))
            scatter = hv.Scatter(self.data)
            error = hv.ErrorBars([(x, self.data[x], self.error[x]) for x in self.data])
            line = hv.Path((xs, ys))
            p = scatter * error * line
        else:
            raise NotImplementedError("plot() not implemented for vector functions.")

        # Plot with 5% empty margins such that the boundary points are visible
        margin = 0.05 * (self.bounds[1] - self.bounds[0])
        plot_bounds = (self.bounds[0] - margin, self.bounds[1] + margin)

        return p.redim(x=dict(range=plot_bounds))


def decreasing_dict() -> Dict:
    """This initialization orders the dictionary from large to small values."""

    def sorting_rule(key, value):
        return -value

    return ItemSortedDict(sorting_rule, SortedDict())
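

# Minimal, self-contained demo (editorial addition, not part of the original
# module). The noisy function below is hypothetical and purely illustrative.
if __name__ == "__main__":
    import random

    def noisy_line(seed_x):
        # f(x) = 2x plus seeded Gaussian noise, so samples are reproducible.
        seed, x = seed_x
        rng = random.Random(seed)
        return 2 * x + rng.gauss(0, 0.1)

    learner = AverageLearner1D(noisy_line, bounds=(-1, 1), min_samples=10)
    while learner.nsamples < 300:
        points, _ = learner.ask(10)
        for point in points:
            learner.tell(point, noisy_line(point))
    print(f"{learner.nsamples} samples spread over {len(learner.data)} x-values")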