# -*- coding: utf-8 -*-
import asyncio
import datetime
import importlib
import random
import warnings
from contextlib import suppress
# Flags recording which notebook integrations have already been set up:
# notebook_extension() sets the first three, ensure_plotly() the last one.
_async_enabled, _holoviews_enabled, _ipywidgets_enabled, _plotly_enabled = (
    False, False, False, False)
def notebook_extension(*, _inline_js=True):
    """Enable ipywidgets, holoviews, and asyncio notebook integration.

    Parameters
    ----------
    _inline_js : bool
        Forwarded as ``inline`` to ``holoviews.notebook_extension``.

    Raises
    ------
    RuntimeError
        If not called from inside a Jupyter notebook.

    Notes
    -----
    A missing holoviews or ipywidgets installation only produces a
    ``RuntimeWarning``; the corresponding feature is then left disabled.
    """
    if not in_ipynb():
        raise RuntimeError('"adaptive.notebook_extension()" may only be run '
                           'from a Jupyter notebook.')

    global _async_enabled, _holoviews_enabled, _ipywidgets_enabled

    # Load holoviews. This is (re)done on every call: after closing a
    # notebook the injected js is gone, so a cached flag cannot be trusted.
    _holoviews_enabled = False
    try:
        import holoviews
        holoviews.notebook_extension('bokeh', logo=False, inline=_inline_js)
        _holoviews_enabled = True
    except ModuleNotFoundError:
        warnings.warn("holoviews is not installed; plotting "
                      "is disabled.", RuntimeWarning)

    # Load ipywidgets (only its availability is checked here).
    if not _ipywidgets_enabled:
        try:
            import ipywidgets  # noqa: F401
        except ModuleNotFoundError:
            warnings.warn("ipywidgets is not installed; live_info "
                          "is disabled.", RuntimeWarning)
        else:
            _ipywidgets_enabled = True

    # Enable asyncio integration (equivalent of ``%gui asyncio``).
    # ``InteractiveShell.magic`` is deprecated; ``run_line_magic`` is the
    # supported API.
    if not _async_enabled:
        get_ipython().run_line_magic('gui', 'asyncio')  # noqa: F821
        _async_enabled = True
def ensure_holoviews():
    """Import and return the ``holoviews`` module.

    Raises
    ------
    RuntimeError
        If holoviews is not installed (the original
        ``ModuleNotFoundError`` is attached as ``__cause__``).
    """
    try:
        return importlib.import_module('holoviews')
    except ModuleNotFoundError as err:
        raise RuntimeError(
            'holoviews is not installed; plotting is disabled.') from err
def ensure_plotly():
    """Import plotly and, on first use, initialise its notebook mode.

    Returns
    -------
    module
        The top-level ``plotly`` module.

    Raises
    ------
    RuntimeError
        If plotly (or one of the submodules imported here) is not
        installed; the ``ModuleNotFoundError`` is attached as ``__cause__``.
    """
    global _plotly_enabled
    # Only the imports are guarded, so that errors raised by
    # init_notebook_mode() are not misreported as "plotly is not installed".
    try:
        import plotly
        if not _plotly_enabled:
            import plotly.graph_objs
            import plotly.figure_factory
            import plotly.offline
    except ModuleNotFoundError as err:
        raise RuntimeError(
            'plotly is not installed; plotting is disabled.') from err

    if not _plotly_enabled:
        # This injects javascript and should happen only once
        plotly.offline.init_notebook_mode()
        _plotly_enabled = True
    return plotly
def in_ipynb():
    """Return whether the code runs inside a Jupyter (ZMQ) IPython kernel."""
    try:
        # Inside IPython, ``get_ipython`` is always available as a global.
        shell = get_ipython()  # noqa: F821
    except NameError:
        # Plain Python interpreter: the name simply does not exist.
        return False
    return shell.__class__.__name__ == 'ZMQInteractiveShell'
# Fancy displays in the Jupyter notebook

# Maps each live_plot ``name`` to the asyncio task refreshing that plot.
active_plotting_tasks = {}
def _current_task():
    """Return the asyncio task that is currently running."""
    # ``asyncio.Task.current_task`` was removed in Python 3.9; the
    # module-level ``asyncio.current_task`` (Python >= 3.7) replaces it.
    current_task = getattr(asyncio, 'current_task', None)
    if current_task is None:  # Python 3.6
        return asyncio.Task.current_task()
    return current_task()


def live_plot(runner, *, plotter=None, update_interval=2,
              name=None, normalize=True):
    """Live plotting of the learner's data.

    Parameters
    ----------
    runner : `Runner`
    plotter : function
        A function that takes the learner as an argument and returns a
        holoviews object. By default ``learner.plot()`` will be called.
    update_interval : int
        Number of seconds between the updates of the plot.
    name : hashable
        Name for the `live_plot` task in `adaptive.active_plotting_tasks`.
        By default the name is None and if another task with the same name
        already exists that other `live_plot` is canceled.
    normalize : bool
        Normalize (scale to fit) the frame upon each update.

    Returns
    -------
    dm : `holoviews.core.DynamicMap`
        The plot that automatically updates every `update_interval`.

    Raises
    ------
    RuntimeError
        If holoviews integration was not enabled by
        ``adaptive.notebook_extension()``.
    """
    if not _holoviews_enabled:
        raise RuntimeError(
            "Live plotting is not enabled; did you run "
            "'adaptive.notebook_extension()'?"
        )

    import holoviews as hv
    import ipywidgets
    from IPython.display import display

    # At most one live plot per name: cancel the previous one.
    if name in active_plotting_tasks:
        active_plotting_tasks[name].cancel()

    def plot_generator():
        while True:
            if not plotter:
                yield runner.learner.plot()
            else:
                yield plotter(runner.learner)

    streams = [hv.streams.Stream.define("Next")()]
    dm = hv.DynamicMap(plot_generator(), streams=streams)

    if normalize:
        # XXX: change when https://github.com/pyviz/holoviews/issues/3637
        # is fixed.
        dm = dm.map(lambda obj: obj.opts(framewise=True), hv.Element)

    cancel_button = ipywidgets.Button(description='cancel live-plot',
                                      layout=ipywidgets.Layout(width='150px'))

    # Could have used dm.periodic in the following, but this would either spin
    # off a thread (and learner is not threadsafe) or block the kernel.
    async def updater():
        def event():
            # XXX: used to be dm.event()
            # see https://github.com/pyviz/holoviews/issues/3564
            hv.streams.Stream.trigger(dm.streams)

        try:
            while not runner.task.done():
                event()
                await asyncio.sleep(update_interval)
            event()  # fire off one last update before we die
        finally:
            # Deregister only if a newer live_plot with the same name has not
            # already replaced this task; ``.get`` avoids a KeyError (which
            # would also skip hiding the button) if the entry is gone.
            if active_plotting_tasks.get(name) is _current_task():
                active_plotting_tasks.pop(name, None)
            cancel_button.layout.display = 'none'  # remove cancel button

    def cancel(_):
        with suppress(KeyError):
            active_plotting_tasks[name].cancel()

    active_plotting_tasks[name] = runner.ioloop.create_task(updater())
    cancel_button.on_click(cancel)

    display(cancel_button)
    return dm
def should_update(status):
    """Decide whether a live_info refresh should be sent to the frontend.

    Parameters
    ----------
    status : object
        The widget being refreshed; its (private) ``comm.kernel.iopub_thread``
        is inspected to measure the kernel's pending output buffer.

    Returns
    -------
    bool
        Always True when the buffer size cannot be determined; otherwise
        True with probability ``1.1 ** -buffer_size``.
    """
    try:
        # Get the length of the write buffer size.
        buffer_size = len(status.comm.kernel.iopub_thread._events)
    except Exception:
        # We catch any Exception because we are using a private API.
        return True

    # Make sure to only keep all the messages when the notebook
    # is viewed, this means 'buffer_size == 1'. However, when not
    # viewing the notebook the buffer fills up. When this happens
    # we decide to only add messages to it when a certain probability.
    # i.e. we're offline for 12h, with an update_interval of 0.5s,
    # and without the reduced probability, we have buffer_size=86400.
    # With the correction this is np.log(86400) / np.log(1.1) = 119.2
    #
    # Written as ``r < 1.1**-n`` rather than ``1.1**n * r < 1``: the latter
    # raises OverflowError for n above roughly 7400 (which used to be
    # swallowed and turned into "always update"), while this form just
    # underflows to 0.0.
    return random.random() < 1.1 ** -buffer_size
def live_info(runner, *, update_interval=0.5):
    """Display live information about the runner.

    Shows an HTML status panel plus a "cancel runner" button in the Jupyter
    notebook (via ``IPython.display.display``). The panel keeps refreshing
    until ``runner.task`` is done. Nothing is returned.

    Parameters
    ----------
    runner : `Runner`
    update_interval : float
        Number of seconds between refreshes of the status panel.

    Raises
    ------
    RuntimeError
        If ipywidgets integration was not enabled by
        ``adaptive.notebook_extension()``.
    """
    # live_info needs ipywidgets, not holoviews (cf. the warning emitted by
    # notebook_extension when ipywidgets is missing).
    if not _ipywidgets_enabled:
        raise RuntimeError("Live info is not enabled; did you run "
                           "'adaptive.notebook_extension()'?")

    import ipywidgets
    from IPython.display import display

    status = ipywidgets.HTML(value=_info_html(runner))

    cancel = ipywidgets.Button(description='cancel runner',
                               layout=ipywidgets.Layout(width='100px'))
    cancel.on_click(lambda _: runner.cancel())

    async def update():
        while not runner.task.done():
            await asyncio.sleep(update_interval)
            # Throttle refreshes when the kernel's output buffer backs up.
            if should_update(status):
                status.value = _info_html(runner)
            else:
                await asyncio.sleep(0.05)

        # Show the final state, then hide the now-useless cancel button.
        status.value = _info_html(runner)
        cancel.layout.display = 'none'

    runner.ioloop.create_task(update())

    display(ipywidgets.HBox(
        (status, cancel),
        layout=ipywidgets.Layout(border='solid 1px',
                                 width='200px',
                                 align_items='center'),
    ))
def _info_html(runner):
    """Render the runner's current state as an HTML ``<dl>`` snippet.

    Always lists status (color-coded), elapsed time and overhead; the
    number of points and latest loss are added only when the learner
    exposes them without raising.
    """
    state = runner.status()
    palette = dict(cancelled='orange', failed='red',
                   running='blue', finished='green')
    rows = [
        ('status', f'<font color="{palette[state]}">{state}</font>'),
        ('elapsed time', datetime.timedelta(seconds=runner.elapsed_time())),
        ('overhead', f'{runner.overhead():.2f}%'),
    ]

    # Best-effort extras: any failure simply omits the row.
    optional_rows = (
        ('# of points', lambda: runner.learner.npoints),
        ('latest loss', lambda: f'{runner.learner._cache["loss"]:.3f}'),
    )
    for label, getter in optional_rows:
        with suppress(Exception):
            rows.append((label, getter()))

    row_template = '<dt class="ignore-css">{}</dt><dd>{}</dd>'
    lines = []
    for label, value in rows:
        lines.append(row_template.format(label, value))
    table = '\n'.join(lines)

    return f'\n<dl>\n{table}\n</dl>\n'