Tutorial DataSaver
Note
Because this documentation consists of static HTML, the `live_plot`
and `live_info` widgets are not live. Download the notebook
in order to see the real behaviour.
See also
The complete source code of this tutorial can be found in
`tutorial.DataSaver.ipynb`.
import adaptive
# Enable adaptive's Jupyter notebook integration; presumably this is what the
# live_plot / live_info widgets used further down rely on — confirm against
# the adaptive documentation.
adaptive.notebook_extension()
If the function that you want to learn returns a value along with some
metadata, you can wrap your learner in an `adaptive.DataSaver`.
In the following example, the function to be learned returns its result and the execution time in a dictionary:
from operator import itemgetter
def f_dict(x, a=0.01, max_waiting_time=1.0):
    """Evaluate a sharply peaked test function at ``x`` and report how long it took.

    The function evaluation takes roughly the time we ``sleep``: a random
    duration drawn uniformly from ``[0, max_waiting_time)`` seconds, which
    stands in for an expensive computation.

    Parameters
    ----------
    x : float
        Point at which to evaluate the function.
    a : float, optional
        Width of the peak around ``x = 0`` (default ``0.01``).
    max_waiting_time : float, optional
        Upper bound of the random sleep in seconds (default ``1.0``).
        Pass ``0`` to evaluate without waiting.

    Returns
    -------
    dict
        ``{'y': value, 'waiting_time': seconds_slept}``. Only ``'y'`` is
        learned (selected by the ``arg_picker`` of the ``DataSaver``); the
        whole dict is kept in ``DataSaver.extra_data``.
    """
    # Imports are kept local, as in the original tutorial — presumably so the
    # function is self-contained when executed in another process by an executor.
    import random
    from time import sleep

    waiting_time = random.random() * max_waiting_time
    sleep(waiting_time)
    # Linear background plus a Lorentzian-shaped peak of height 1 at x = 0.
    y = x + a**2 / (a**2 + x**2)
    return {'y': y, 'waiting_time': waiting_time}
# A bare Learner1D cannot be run on f_dict: it would not know what to do
# with the dict that f_dict returns.
_learner = adaptive.Learner1D(
    f_dict,
    bounds=(-1, 1),
)
# Wrap it in an adaptive.DataSaver; ``arg_picker`` tells it which key of the
# returned dict ('y') the wrapped learner should learn.
learner = adaptive.DataSaver(
    _learner,
    arg_picker=itemgetter('y'),
)
`learner.learner` is the original learner, so `learner.learner.loss()`
will call the correct loss method.
runner = adaptive.Runner(learner, goal=lambda l: l.learner.loss() < 0.1)
await runner.task # This is not needed in a notebook environment!
runner.live_info()
runner.live_plot(plotter=lambda l: l.learner.plot(), update_interval=0.1)
Now the `DataSaver` will have a dictionary attribute `extra_data` that
has `x` as keys and the data that was returned by `learner.function`
as values.
# Display the saved metadata: one entry per evaluated x, holding the full dict
# returned by f_dict (both 'y' and 'waiting_time').
learner.extra_data
OrderedDict([(1, {'y': 1.000099990001, 'waiting_time': 0.47508615843437285}),
(-1,
{'y': -0.9999000099990001, 'waiting_time': 0.7970502914031158}),
(0.0, {'y': 1.0, 'waiting_time': 0.6444894636356452}),
(-0.75,
{'y': -0.7498222538215429, 'waiting_time': 0.21584433196969488}),
(-0.5,
{'y': -0.4996001599360256, 'waiting_time': 0.7257488252756469}),
(-0.25,
{'y': -0.24840255591054314, 'waiting_time': 0.709707390356665}),
(-0.125,
{'y': -0.11864069952305246,
'waiting_time': 0.057482163218342985}),
(-0.0625,
{'y': -0.0375390015600624, 'waiting_time': 0.0620452673812979}),
(0.5,
{'y': 0.5003998400639744, 'waiting_time': 0.8214860991559911}),
(0.75,
{'y': 0.7501777461784571, 'waiting_time': 0.5434958492780316}),
(-0.03125,
{'y': 0.06163824383164006, 'waiting_time': 0.7905975231668958}),
(0.25,
{'y': 0.2515974440894569, 'waiting_time': 0.12591674575199618}),
(0.125,
{'y': 0.13135930047694755, 'waiting_time': 0.10747153212798022}),
(-0.015625,
{'y': 0.27495388762769585, 'waiting_time': 0.22235195561959809}),
(0.0625,
{'y': 0.0874609984399376, 'waiting_time': 0.5084370475196994}),
(-0.0078125,
{'y': 0.6131699135839903, 'waiting_time': 0.5708648049160236}),
(0.015625,
{'y': 0.30620388762769585, 'waiting_time': 0.11572751351428023}),
(0.0078125,
{'y': 0.6287949135839903, 'waiting_time': 0.23648680471800254}),
(0.03125,
{'y': 0.12413824383164006, 'waiting_time': 0.9317882451497077}),
(0.00390625,
{'y': 0.8715190438995976, 'waiting_time': 0.10997374954764194}),
(-0.00390625,
{'y': 0.8637065438995976, 'waiting_time': 0.7854232833945709}),
(-0.625,
{'y': -0.6247440655192271, 'waiting_time': 0.09025183360966282}),
(-0.375,
{'y': -0.3742893942085628, 'waiting_time': 0.38319924277609974}),
(-0.875,
{'y': -0.8748694048124327, 'waiting_time': 0.325312133354032}),
(0.875,
{'y': 0.8751305951875673, 'waiting_time': 0.7052379455997435}),
(0.375,
{'y': 0.3757106057914372, 'waiting_time': 0.21427402821958785}),
(0.625,
{'y': 0.6252559344807729, 'waiting_time': 0.7152488978639219}),
(-0.01171875,
{'y': 0.40963707758975415, 'waiting_time': 0.40349946890142085}),
(-0.005859375,
{'y': 0.7385633611533918, 'waiting_time': 0.3499236250477289}),
(0.01171875,
{'y': 0.43307457758975415, 'waiting_time': 0.7307355376086471}),
(0.005859375,
{'y': 0.7502821111533918, 'waiting_time': 0.44856724453312546}),
(-0.009765625,
{'y': 0.5020904154886127,
'waiting_time': 0.0016725311933690756}),
(-0.0234375,
{'y': 0.1305706215220334, 'waiting_time': 0.5240530882947154})])