Live Editing Source Code
This example demonstrates how to create an interactive code editor with live evaluation and plotting. You'll see state management using Plot.State, widget.state.update, and Plot.onChange, as well as live code evaluation using Python's exec.
import jax
import jax.numpy as jnp
import colight.plot as Plot
from colight.plot import js
key = jax.random.key(314159)
thetas = jnp.arange(0.0, 1.0, 0.0005)
sigma = 0.05
def noisy_jax_model(key, theta, sigma):
# Sample a bernoulli random variable to determine which noise model to use
b = jax.random.bernoulli(key, theta)
# If b=True: noise proportional to theta, if b=False: constant noise plus linear term
return jax.lax.cond(
b,
lambda theta: jax.random.normal(key) * sigma * theta,
lambda theta: jax.random.normal(key) * sigma + theta * 2,
theta,
)
def make_samples(key, thetas, sigma, model_func):
# Vectorize model over array of thetas using unique random keys for each
return jax.vmap(model_func, in_axes=(0, 0, None))(
jax.random.split(key, len(thetas)), thetas, sigma
)
initial_source = """sigma = 0.05
def noisy_jax_model(key, theta, sigma):
# Sample a bernoulli random variable to determine which noise model to use
b = jax.random.bernoulli(key, theta)
# If b=True: noise proportional to theta, if b=False: constant noise plus linear term
return jax.lax.cond(
b,
lambda theta: jax.random.normal(key) * sigma * theta,
lambda theta: jax.random.normal(key) * sigma + theta * 2,
theta,
)"""
initial_state = Plot.State(
{
"samples": make_samples(key, thetas, sigma, noisy_jax_model),
"thetas": thetas,
"toEval": "",
"source": initial_source,
}
)
Callback function
def evaluate(widget, _e):
# Update random key and evaluate new code from text editor
global key, sigma, noisy_jax_model
key, subkey = jax.random.split(key, 2)
source = f"global sigma, noisy_jax_model\n{widget.state.toEval}"
exec(source)
widget.state.update(
{"samples": make_samples(subkey, thetas, sigma, noisy_jax_model)}
)
Plot.dot will render our samples as a scatter plot. We pass $state.thetas and $state.samples in columnar format.
samples_plot = Plot.dot(
{"x": js("$state.thetas"), "y": js("$state.samples")}, fill="rgba(0, 128, 128, 0.3)"
) + {"height": 400}
(
initial_state
| Plot.onChange({"toEval": evaluate})
| Plot.html(
[
"form.!flex.flex-col.gap-3",
{
"onSubmit": js(
"e => { e.preventDefault(); $state.toEval = $state.source}"
)
},
samples_plot,
[
"textarea.whitespace-pre-wrap.text-[13px].lh-normal.p-3.rounded-md.bg-gray-100.flex-1.h-[300px].font-mono",
{
"rows": js("$state.source.split('\\n').length+1"),
"onChange": js("(e) => $state.source = e.target.value"),
"value": js("$state.source"),
"onKeyDown": js(
"(e) => { if (e.ctrlKey && e.key === 'Enter') { e.stopPropagation(); $state.toEval = $state.source } }"
),
},
],
[
"div.flex.items-stretch",
[
"button.flex-auto.!bg-blue-500.!hover:bg-blue-600.text-white.text-center.px-4.py-2.rounded-md.cursor-pointer",
{"type": "submit"},
"Evaluate and Plot",
],
[
"div.flex.items-center.p-2",
{
"onClick": lambda widget, _: widget.state.update(
{"source": initial_source}
)
},
"Reset Source",
],
],
]
)
)