# Based on an adaptive quadrature algorithm by Pedro Gonnet
import sys
from collections import defaultdict
from math import sqrt
from operator import attrgetter
import numpy as np
from scipy.linalg import norm
from sortedcontainers import SortedSet
from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews
from adaptive.utils import cache_latest, restore
from .integrator_coeffs import (
T_left,
T_right,
V_inv,
Vcond,
alpha,
b_def,
eps,
gamma,
hint,
min_sep,
ndiv_max,
ns,
xi,
)
def _downdate(c, nans, depth):
    """Remove the contribution of discarded nodes from the coefficients ``c``.

    Parameters
    ----------
    c : numpy array
        Coefficients of the fit at ``depth``. Modified in place.
    nans : iterable of int
        Indices (into ``xi[depth]``) of the nodes to remove, one downdate
        step is performed per index.
    depth : int
        Refinement level, used to select ``b_def``, ``ns`` and ``xi``.

    Returns
    -------
    numpy array
        The same array ``c`` that was passed in, after the in-place update.
    """
    # This is algorithm 5 from the thesis of Pedro Gonnet.
    # Work on a copy so the module-level table `b_def` is never mutated.
    b = b_def[depth].copy()
    # Index of the highest coefficient currently in use
    # (presumably ns[depth] is the number of nodes at this depth -- TODO confirm).
    m = ns[depth] - 1
    for i in nans:
        # Update `b` from the top index downwards using the node `xii`
        # that is being removed.
        b[m + 1] /= alpha[m]
        xii = xi[depth][i]
        b[m] = (b[m] + xii * b[m + 1]) / alpha[m - 1]
        for j in range(m - 1, 0, -1):
            b[j] = (b[j] + xii * b[j + 1] - gamma[j + 1] * b[j + 2]) / alpha[j - 1]
        # Drop the first entry; `b` is one element shorter for the next node.
        b = b[1:]
        # Subtract the multiple of `b` that cancels c[m], then zero c[m]
        # explicitly and continue with one coefficient fewer.
        c[:m] -= c[m] / b[m] * b[:m]
        c[m] = 0
        m -= 1
    return c
def _zero_nans(fx):
    """Replace every non-finite entry of ``fx`` by 0.0 and report where.

    Caution: this function modifies ``fx`` in place.

    Parameters
    ----------
    fx : mutable sequence of float
        Function values; NaN and +/-inf entries are overwritten with 0.0.

    Returns
    -------
    list of int
        Positions of the entries that were not finite, in increasing order.
    """
    bad_positions = [
        pos for pos, value in enumerate(fx) if not np.isfinite(value)
    ]
    for pos in bad_positions:
        fx[pos] = 0.0
    return bad_positions
def _calc_coeffs(fx, depth):
    """Compute the fit coefficients for the function values ``fx`` at ``depth``.

    Caution: this function modifies ``fx``. Non-finite entries are zeroed for
    the matrix product and afterwards all of them (including infinities) are
    written back as NaN.

    Parameters
    ----------
    fx : numpy array
        Function values at the nodes of the given depth.
    depth : int
        Refinement level, selects the matrix ``V_inv[depth]``.

    Returns
    -------
    numpy array
        The coefficients, downdated for the non-finite nodes if there were any.
    """
    bad_positions = _zero_nans(fx)
    coeffs = V_inv[depth] @ fx
    if not bad_positions:
        return coeffs
    # Restore the markers in `fx` and remove the bad nodes from the fit.
    fx[bad_positions] = np.nan
    return _downdate(coeffs, bad_positions, depth)
class DivergentIntegralError(ValueError):
    """Raised when an interval's divergence counter exceeds its limit."""
class _Interval:
    """A (sub)interval of the integration domain and the state of its fit.

    Attributes
    ----------
    (a, b) : (float, float)
        The left and right boundary of the interval.
    c : numpy array
        Coefficients of the fit, as returned by `_calc_coeffs` for the depth
        that was completed last.
    c00 : float
        The first coefficient stored when depth 0 is completed (set to 0.0
        for the first interval). Used by `calc_ndiv`.
    depth : int
        The level of refinement, `depth=0` means that it has 5 (the minimal
        number of) points and `depth=3` means it has 33 (the maximal number
        of) points.
    fx : numpy array of size `(5, 9, 17, 33)[self.depth]`.
        The function values at the points `self.points(self.depth)`.
    igral : float
        The integral value of the interval.
    err : float
        The error associated with the integral value.
    rdepth : int
        The number of splits that the interval has gone through, starting at 1.
    ndiv : int
        A number that is used to determine whether the interval is divergent.
    parent : _Interval
        The parent interval.
    children : list of `_Interval`s
        The intervals resulting from a split.
    data : dict
        A dictionary with the x-values and y-values: `{x1: y1, x2: y2 ...}`.
    done_leaves : set or None
        Leaves used for the error and the integral estimation of this
        interval. None means that this information was already propagated to
        the ancestors of this interval.
    depth_complete : int or None
        The level of refinement at which the interval has the integral value
        evaluated. If None there is no level at which the integral value is
        known yet.
    removed : bool
        Flag that starts out False; it is not changed by this class itself.

    Methods
    -------
    refinement_complete : depth, optional
        If true, all the function values in the interval are known at `depth`.
        By default the depth is the depth of the interval.
    """

    __slots__ = [
        "a",
        "b",
        "c",
        "c00",
        "depth",
        "igral",
        "err",
        "fx",
        "rdepth",
        "ndiv",
        "parent",
        "children",
        "data",
        "done_leaves",
        "depth_complete",
        "removed",
    ]

    def __init__(self, a, b, depth, rdepth):
        self.children = []
        self.data = {}
        self.a = a
        self.b = b
        self.depth = depth
        self.rdepth = rdepth
        self.done_leaves = set()
        self.depth_complete = None
        self.removed = False

    @classmethod
    def make_first(cls, a, b, depth=2):
        """Create the root interval ``(a, b)`` (no parent, ``rdepth=1``)."""
        ival = cls(a, b, depth, rdepth=1)
        ival.ndiv = 0
        ival.parent = None
        ival.err = sys.float_info.max  # needed because inf/2 == inf
        return ival

    @property
    def T(self):
        """Get the correct shift matrix.

        Should only be called on children of a split interval.
        """
        assert self.parent is not None
        left = self.a == self.parent.a
        right = self.b == self.parent.b
        # A child shares exactly one boundary with its parent.
        assert left != right
        return T_left if left else T_right

    def refinement_complete(self, depth=None):
        """The interval has all the y-values to calculate the integral.

        Parameters
        ----------
        depth : int, optional
            Refinement level to check. Defaults to ``self.depth``.
        """
        if depth is None:
            depth = self.depth
        # Cheap size check first; only then look up every point.
        if len(self.data) < ns[depth]:
            return False
        return all(p in self.data for p in self.points(depth))

    def points(self, depth=None):
        """Return the x-values of this interval at `depth` (default: own depth)."""
        if depth is None:
            depth = self.depth
        a = self.a
        b = self.b
        return (a + b) / 2 + (b - a) * xi[depth] / 2

    def refine(self):
        """Increase the depth by one and return the interval itself."""
        self.depth += 1
        return self

    def split(self):
        """Split at the middle point into two depth-0 children and return them.

        Each child inherits `ndiv` and starts with half of this interval's
        error as its error.
        """
        points = self.points()
        m = points[len(points) // 2]
        make = type(self)
        ivals = [
            make(self.a, m, 0, self.rdepth + 1),
            make(m, self.b, 0, self.rdepth + 1),
        ]
        self.children = ivals
        for ival in ivals:
            ival.parent = self
            ival.ndiv = self.ndiv
            ival.err = self.err / 2
        return ivals

    def calc_igral(self):
        """Set `igral` from the first fit coefficient and the interval width."""
        self.igral = (self.b - self.a) * self.c[0] / sqrt(2)

    def update_heuristic_err(self, value):
        """Sets the error of an interval using a heuristic (half the error of
        the parent) when the actual error cannot be calculated due to its
        parents not being finished yet. This error is propagated down to its
        children."""
        self.err = value
        for child in self.children:
            # Skip children completed at depth > 0, and children completed at
            # depth 0 when this interval has itself completed some depth.
            if child.depth_complete or (
                child.depth_complete == 0 and self.depth_complete is not None
            ):
                continue
            child.update_heuristic_err(value / 2)

    def calc_err(self, c_old):
        """Set `err` from the difference between `c_old` and the current `c`.

        The shorter coefficient vector is zero-padded. Children that have not
        completed any depth receive half of the new error as heuristic error.

        Returns
        -------
        float
            The norm of the coefficient difference.
        """
        c_new = self.c
        c_diff = np.zeros(max(len(c_old), len(c_new)))
        c_diff[: len(c_old)] = c_old
        c_diff[: len(c_new)] -= c_new
        c_diff = norm(c_diff)
        self.err = (self.b - self.a) * c_diff
        for child in self.children:
            if child.depth_complete is None:
                child.update_heuristic_err(self.err / 2)
        return c_diff

    def calc_ndiv(self):
        """Update the divergence counter by comparing `c00` with the parent's.

        Raises
        ------
        DivergentIntegralError
            If `ndiv` exceeds `ndiv_max` and more than half of `rdepth`.
        """
        parent_c00 = self.parent.c00
        # A zero parent coefficient never counts as divergence (and avoids
        # dividing by zero). Coerce to a plain bool so that `ndiv` stays an
        # int instead of turning into a float or a numpy scalar.
        div = bool(parent_c00) and bool(self.c00 / parent_c00 > 2)
        self.ndiv += int(div)
        if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth:
            raise DivergentIntegralError
        if div:
            for child in self.children:
                child.update_ndiv_recursively()

    def update_ndiv_recursively(self):
        """Increment `ndiv` of this interval and of all its descendants.

        Raises
        ------
        DivergentIntegralError
            If `ndiv` exceeds `ndiv_max` and more than half of `rdepth`.
        """
        self.ndiv += 1
        if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth:
            raise DivergentIntegralError
        for child in self.children:
            child.update_ndiv_recursively()

    def complete_process(self, depth):
        """Calculate the integral contribution and error from this interval,
        and update the done leaves of all ancestor intervals.

        Returns
        -------
        (force_split, remove) : (bool, bool)
            `force_split` is set when refining changed the coefficients by
            more than `hint` times their norm; `remove` is set when the error
            is below `abs(igral) * eps * Vcond[depth]`.
        """
        assert self.depth_complete is None or self.depth_complete == depth - 1
        self.depth_complete = depth

        fx = [self.data[k] for k in self.points(depth)]
        self.fx = np.array(fx)
        force_split = False  # This may change when refining

        first_ival = self.parent is None and depth == 2

        if depth and not first_ival:
            # Store for usage in refine
            c_old = self.c

        self.c = _calc_coeffs(self.fx, depth)

        if first_ival:
            # The root has no earlier estimate to compare against.
            self.c00 = 0.0
            return False, False

        self.calc_igral()

        if depth:
            # Refine
            c_diff = self.calc_err(c_old)
            force_split = c_diff > hint * norm(self.c)
        else:
            # Split
            self.c00 = self.c[0]

            if self.parent.depth_complete is not None:
                c_old = self.T[:, : ns[self.parent.depth_complete]] @ self.parent.c
                self.calc_err(c_old)
                self.calc_ndiv()

            for child in self.children:
                if child.depth_complete is not None:
                    child.calc_ndiv()
                if child.depth_complete == 0:
                    c_old = child.T[:, : ns[self.depth_complete]] @ self.c
                    child.calc_err(c_old)

        if self.done_leaves is not None and not self.done_leaves:
            # This interval contributes to the integral estimate.
            self.done_leaves = {self}

            # Use this interval in the integral estimates of the ancestors
            # while possible.
            ival = self.parent
            old_leaves = set()
            while ival is not None:
                unused_children = [
                    child for child in ival.children if child.done_leaves is not None
                ]

                # Stop as soon as some child has nothing to contribute yet.
                if not all(child.done_leaves for child in unused_children):
                    break

                if ival.done_leaves is None:
                    ival.done_leaves = set()
                old_leaves.add(ival)
                for child in ival.children:
                    if child.done_leaves is None:
                        continue
                    ival.done_leaves.update(child.done_leaves)
                    child.done_leaves = None
                ival.done_leaves -= old_leaves
                ival = ival.parent

        remove = self.err < (abs(self.igral) * eps * Vcond[depth])

        return force_split, remove

    def __repr__(self):
        # `igral` does not exist until a depth has been completed.
        igral = getattr(self, "igral", np.inf)
        lst = [
            f"(a, b)=({self.a:.5f}, {self.b:.5f})",
            f"depth={self.depth}",
            f"rdepth={self.rdepth}",
            f"err={self.err:.5E}",
            f"igral={igral:.5E}",
        ]
        return " ".join(lst)
class IntegratorLearner(BaseLearner):
    """Learner that estimates the integral of a function on an interval
    using an adaptive quadrature scheme built on `_Interval` objects."""

    def __init__(self, function, bounds, tol):
        """
        Parameters
        ----------
        function : callable: X → Y
            The function to learn.
        bounds : pair of reals
            The bounds of the interval on which to learn 'function'.
        tol : float
            Relative tolerance of the error to the integral, this means that
            the learner is done when: `tol > err / abs(igral)`.

        Attributes
        ----------
        approximating_intervals : set of intervals
            The intervals that can be used in the determination of the integral.
        npoints : int
            The total number of evaluated points.
        igral : float
            The integral value in `self.bounds`.
        err : float
            The absolute error associated with `self.igral`.
        max_ivals : int, default: 1000
            Maximum number of intervals that can be present in the calculation
            of the integral. If this amount exceeds max_ivals, the interval
            with the smallest error will be discarded.

        Methods
        -------
        done : bool
            Returns whether the `tol` has been reached.
        plot : hv.Scatter
            Plots all the points that are evaluated.
        """
        self.function = function
        self.bounds = bounds
        self.tol = tol
        self.max_ivals = 1000
        self.priority_split = []
        self.data = {}
        self.pending_points = set()
        self._stack = []
        self.x_mapping = self._new_x_mapping()
        self.ivals = set()
        ival = _Interval.make_first(*self.bounds)
        self.add_ival(ival)
        self.first_ival = ival

    @staticmethod
    def _new_x_mapping():
        """Return an empty point -> intervals mapping (sets sorted by rdepth)."""
        return defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))

    @property
    def approximating_intervals(self):
        """The done leaves of the first interval."""
        return self.first_ival.done_leaves

    def tell(self, point, value):
        """Store ``value`` for ``point`` and update every interval that uses it.

        Raises
        ------
        ValueError
            If ``point`` was never registered for any interval.
        """
        if point not in self.x_mapping:
            raise ValueError(f"Point {point} doesn't belong to any interval")
        self.data[point] = value
        self.pending_points.discard(point)

        # Select the intervals that have this point
        ivals = self.x_mapping[point]
        for ival in ivals:
            ival.data[point] = value

            # A child starts at depth 0; the first interval starts at depth 2.
            if ival.depth_complete is None:
                from_depth = 0 if ival.parent is not None else 2
            else:
                from_depth = ival.depth_complete + 1

            for depth in range(from_depth, ival.depth + 1):
                if ival.refinement_complete(depth):
                    force_split, remove = ival.complete_process(depth)

                    if remove:
                        # Remove the interval (while remembering the excess
                        # integral and error), since it is either too narrow,
                        # or the estimated relative error is already at the
                        # limit of numerical accuracy and cannot be reduced
                        # further.
                        self.propagate_removed(ival)

                    elif force_split and not ival.children:
                        # If it already has children it has already been split
                        assert ival in self.ivals
                        self.priority_split.append(ival)

    def tell_pending(self):
        """Do nothing; pending points are tracked by `add_ival` instead."""

    def propagate_removed(self, ival):
        """Mark ``ival`` and all its descendants as removed and drop them
        from `self.ivals`."""

        def _propagate_removed_down(ival):
            ival.removed = True
            self.ivals.discard(ival)

            for child in ival.children:
                _propagate_removed_down(child)

        _propagate_removed_down(ival)

    def add_ival(self, ival):
        """Register ``ival``: map its points, reuse known values and queue
        the unknown points for evaluation."""
        for x in ival.points():
            # Update the mappings
            self.x_mapping[x].add(ival)
            if x in self.data:
                self.tell(x, self.data[x])
            elif x not in self.pending_points:
                self.pending_points.add(x)
                self._stack.append(x)
        self.ivals.add(ival)

    def ask(self, n, tell_pending=True):
        """Choose points for learners.

        Returns a tuple ``(points, loss_improvements)``. With
        ``tell_pending=False`` the request runs inside ``restore(self)``
        (presumably undoing the state changes afterwards -- see
        `adaptive.utils.restore`).
        """
        if not tell_pending:
            with restore(self):
                return self._ask_and_tell_pending(n)
        else:
            return self._ask_and_tell_pending(n)

    def _ask_and_tell_pending(self, n):
        """Pop ``n`` points from the stack, refilling it as often as needed.

        Raises
        ------
        RuntimeError
            If `_fill_stack` raises a ValueError (for example because there
            is no interval left to refine or split).
        """
        points, loss_improvements = self.pop_from_stack(n)
        n_left = n - len(points)
        while n_left > 0:
            try:
                self._fill_stack()
            except ValueError as err:
                # Keep the original error as the cause: any ValueError raised
                # inside `_fill_stack` ends up here, including
                # DivergentIntegralError, which is a ValueError subclass.
                raise RuntimeError(
                    "No way to improve the integral estimate."
                ) from err
            new_points, new_loss_improvements = self.pop_from_stack(n_left)
            points += new_points
            loss_improvements += new_loss_improvements
            n_left -= len(new_points)

        return points, loss_improvements

    def pop_from_stack(self, n):
        """Take up to ``n`` points from the front of the stack.

        The loss improvement of a point is the largest error among the
        intervals that contain it.
        """
        points = self._stack[:n]
        self._stack = self._stack[n:]
        loss_improvements = [
            max(ival.err for ival in self.x_mapping[x]) for x in points
        ]
        return points, loss_improvements

    def remove_unfinished(self):
        """Do nothing."""

    def _fill_stack(self):
        """Refine or split one interval so that new points land on the stack.

        Returns the stack. Raises ValueError (from ``max``) when there are no
        intervals left to choose from.
        """
        # XXX: to-do if all the ivals have err=inf, take the interval
        # with the lowest rdepth and no children.
        force_split = bool(self.priority_split)
        if force_split:
            ival = self.priority_split.pop()
        else:
            ival = max(self.ivals, key=lambda x: (x.err, x.a))

        assert not ival.children

        # If the interval points are smaller than machine precision, then
        # don't continue with splitting or refining.
        points = ival.points()

        # NOTE(review): the right-hand sides are signed, so for negative
        # x-values these comparisons look like they can never be true --
        # confirm whether abs() was intended.
        if (
            points[1] - points[0] < points[0] * min_sep
            or points[-1] - points[-2] < points[-2] * min_sep
        ):
            self.ivals.remove(ival)
        elif ival.depth == 3 or force_split:
            # Always split when depth is maximal or if refining didn't help
            self.ivals.remove(ival)
            for child in ival.split():
                self.add_ival(child)
        else:
            self.add_ival(ival.refine())

        # Remove the interval with the smallest error
        # if number of intervals is larger than max_ivals
        if len(self.ivals) > self.max_ivals:
            self.ivals.remove(min(self.ivals, key=lambda x: (x.err, x.a)))

        return self._stack

    @property
    def npoints(self):
        """Number of evaluated points."""
        return len(self.data)

    @property
    def igral(self):
        """Sum of the integrals of the approximating intervals."""
        return sum(i.igral for i in self.approximating_intervals)

    @property
    def err(self):
        """Sum of the errors of the approximating intervals, or infinity when
        there are none or the sum exceeds the largest finite float."""
        if self.approximating_intervals:
            err = sum(i.err for i in self.approximating_intervals)
            if err > sys.float_info.max:
                err = np.inf
        else:
            err = np.inf
        return err

    def done(self):
        """Return whether the tolerance is reached or nothing is left to do."""
        err = self.err
        igral = self.igral
        # Error carried by approximating intervals that were removed.
        err_excess = sum(i.err for i in self.approximating_intervals if i.removed)
        return (
            err == 0
            or err < abs(igral) * self.tol
            or (err - err_excess < abs(igral) * self.tol < err_excess)
            or not self.ivals
        )

    @cache_latest
    def loss(self, real=True):
        """Distance between the current error and the requested tolerance.

        The ``real`` argument is accepted but not used here.
        """
        return abs(abs(self.igral) * self.tol - self.err)

    def plot(self):
        """Return an ``hv.Path`` through the data of the current intervals."""
        hv = ensure_holoviews()
        ivals = sorted(self.ivals, key=attrgetter("a"))
        pairs = [(x, y) for ival in ivals for x, y in sorted(ival.data.items())]
        # Covers both "no data at all" and "no live interval holds data";
        # unpacking zip(*[]) would raise a ValueError.
        if not pairs:
            return hv.Path([])
        xs, ys = zip(*pairs)
        return hv.Path((xs, ys))

    def _get_data(self):
        """Return the learner state as a tuple accepted by `_set_data`."""
        # Change the defaultdict of SortedSets to a normal dict of sets.
        x_mapping = {k: set(v) for k, v in self.x_mapping.items()}

        return (
            self.priority_split,
            self.data,
            self.pending_points,
            self._stack,
            x_mapping,
            self.ivals,
            self.first_ival,
        )

    def _set_data(self, data):
        """Restore the learner state from a tuple produced by `_get_data`."""
        (
            self.priority_split,
            self.data,
            self.pending_points,
            self._stack,
            x_mapping,
            self.ivals,
            self.first_ival,
        ) = data

        # Add the pending_points to the _stack such that they are evaluated again
        queued = set(self._stack)
        for x in self.pending_points:
            if x not in queued:
                self._stack.append(x)

        # x_mapping is a data structure that can't easily be saved
        # so we recreate it here
        self.x_mapping = self._new_x_mapping()
        for k, _set in x_mapping.items():
            self.x_mapping[k].update(_set)

    def __getstate__(self):
        return (
            self.function,
            self.bounds,
            self.tol,
            self._get_data(),
        )

    def __setstate__(self, state):
        function, bounds, tol, data = state
        self.__init__(function, bounds, tol)
        self._set_data(data)