diff --git a/diffrax/_root_finder/_verychord.py b/diffrax/_root_finder/_verychord.py index 921a6dd4..385a2cb9 100644 --- a/diffrax/_root_finder/_verychord.py +++ b/diffrax/_root_finder/_verychord.py @@ -162,7 +162,7 @@ def terminate( converged = _converged(factor, self.kappa) terminate = at_least_two & (small | diverged | converged) terminate_result = optx.RESULTS.where( - jnp.invert(small) & (diverged | jnp.invert(converged)), + at_least_two & jnp.invert(small) & (diverged | jnp.invert(converged)), optx.RESULTS.nonlinear_divergence, optx.RESULTS.successful, ) diff --git a/pyproject.toml b/pyproject.toml index 880c50c9..01b2e8f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Scientific/Engineering :: Mathematics" ] -dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10", "wadler_lindig>=0.1.1"] +dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.1.0", "wadler_lindig>=0.1.1"] description = "GPU+autodiff-capable ODE/SDE/CDE solvers written in JAX." keywords = ["jax", "dynamical-systems", "differential-equations", "deep-learning", "equinox", "neural-differential-equations", "diffrax"] license = {file = "LICENSE"} diff --git a/test/test_very_chord.py b/test/test_very_chord.py index c0d0287a..17424630 100644 --- a/test/test_very_chord.py +++ b/test/test_very_chord.py @@ -24,6 +24,9 @@ def _fn2(x, args): @jax.jit def _fn3(x, args): mlp = eqx.nn.MLP(4, 4, 256, 2, key=jr.PRNGKey(678)) + dynamic, static = eqx.partition(mlp, eqx.is_array) + dynamic = jtu.tree_map(lambda x: x * 0.1, dynamic) + mlp = eqx.combine(dynamic, static) return mlp(x) - x