Skip to content

[BUG] bm.pre2post_event_sum requires numba on CPU, but not included in brainpy[cpu] #821

@VeriTas-arch

Description

@VeriTas-arch

Bug Description

`pip install "brainpy[cpu]"` succeeds but does not install numba. Calling `bm.pre2post_event_sum` on CPU then fails at runtime because numba is missing.

To Reproduce

Minimal code example (saved as `test.py`) that reproduces the issue:

import importlib.util

import brainpy as bp
import brainpy.math as bm


class Test(bp.dyn.SynConn):
    """Minimal synapse that feeds buffered presynaptic spikes into ``post.input``.

    ``update`` routes the spikes through ``bm.pre2post_event_sum``, which is
    the code path that fails on CPU when numba is not installed.
    """

    def __init__(self, pre, post, conn):
        super().__init__(pre=pre, post=post, conn=conn)
        # Connectivity in the "pre2post" form consumed by pre2post_event_sum.
        self.pre2post = self.conn.require("pre2post")
        # Delay buffer seeded with the presynaptic spike vector; the second
        # argument (0) is presumably the delay length -- confirm in brainpy docs.
        self.delay = bm.LengthDelay(self.pre.spike, 0)

    def update(self):
        # Read from the buffer first, then push the current presynaptic spikes.
        buffered_spikes = self.delay(0)
        self.delay.update(self.pre.spike)
        # Minimal trigger path for the missing-numba failure. The trailing 1.0
        # is presumably a homogeneous weight -- NOTE(review): confirm.
        summed_input = bm.pre2post_event_sum(
            buffered_spikes, self.pre2post, self.post.num, 1.0
        )
        self.post.input += summed_input


def main():
    """Report the environment, then build and run a one-neuron network.

    Prints the BrainPy version and whether numba can be found, wires a single
    spike source to a single LIF neuron through ``Test``, and runs the network
    with ``DSRunner.predict(1.0)``. Without numba on CPU, the run is expected
    to raise.
    """
    print("brainpy:", bp.__version__)
    has_numba = importlib.util.find_spec("numba") is not None
    print("numba installed:", has_numba)

    # One presynaptic neuron that emits a single spike at time 0.0.
    source = bp.neurons.SpikeTimeGroup(1, times=[0.0], indices=[0])
    target = bp.neurons.LIF(1)
    synapse = Test(source, target, conn=bp.connect.All2All())
    network = bp.DynSysGroup(pre=source, syn=synapse, post=target)

    # On CPU without numba, the failure surfaces during this run.
    bp.DSRunner(network).predict(1.0)


if __name__ == "__main__":
    main()

Steps to reproduce:

  1. conda create -n test python=3.12 -c conda-forge
  2. conda activate test
  3. pip install "brainpy[cpu]"
  4. python test.py

Expected Behavior

Either:

  1. This code path works with only `brainpy[cpu]` installed (for example, because numba is included in the `cpu` extra), or
  2. numba is clearly declared or documented as a runtime requirement for this API on CPU.

Actual Behavior

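Executing `bm.pre2post_event_sum` on CPU raises `ModuleNotFoundError: No module named 'numba'`. The error comes from the CPU kernel of `brainevent.binary_csrmv` (`_csrmv_numba_kernel`), which imports numba lazily.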

brainpy: 2.7.7
numba installed: False
Traceback (most recent call last):
  File "D:\miniconda3\envs\test\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2519, in _lower_jaxpr_to_fun_cached
    func_op, _, _ = ctx.cached_primitive_lowerings[key]
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: ('closed_call', let _where = { lambda ; a:bool[1] b:f32[] c:f32[1]. let
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    e:f32[1] = broadcast_in_dim d
    f:f32[1] = select_n a c e
  in (f,) } in
let _where1 = { lambda ; g:bool[] h:i32[] i:i32[]. let
    j:i32[] = select_n g i h
  in (j,) } in
{ lambda ; k:i32[1] l:i32[2] m:f32[1] n:i32[1] o:i32[1] p:bool[1,1] q:bool[1] r:f32[1]
    s:i32[] t:f32[1] u:f32[1] v:bool[1] w:i32[] x:i32[]. let
    y:f32[] = convert_element_type[new_dtype=float32 weak_type=True] x
    z:f32[] = mul y 0.1:f32[]
    ba:f32[] = add 0.0:f32[] z
    bb:i32[] = squeeze[dimensions=(0,)] o
    bc:i32[] = add bb 0:i32[]
    bd:i32[] = jit[
      name=remainder
      jaxpr={ lambda ; bc:i32[] be:i32[]. let
          bf:i32[] = convert_element_type[new_dtype=int32 weak_type=False] be
          bg:bool[] = eq bf 0:i32[]
          bh:i32[] = jit[name=_where jaxpr=_where1] bg 1:i32[] bf
          bi:i32[] = rem bc bh
          bj:bool[] = ne bi 0:i32[]
          bk:bool[] = lt bi 0:i32[]
          bl:bool[] = lt bh 0:i32[]
          bm:bool[] = ne bk bl
          bn:bool[] = and bm bj
          bo:i32[] = add bi bh
          bd:i32[] = select_n bn bi bo
        in (bd,) }
    ] bc 1:i32[]
    bp:i32[] = stop_gradient bd
    bq:bool[] = lt bp 0:i32[]
    br:i32[] = add bp 1:i32[]
    bs:i32[] = select_n bq bp br
    bt:bool[] = lt 0:i32[] 0:i32[]
    bu:i32[] = add 0:i32[] 1:i32[]
    bv:i32[] = select_n bt 0:i32[] bu
    bw:bool[1,1] = dynamic_slice[slice_sizes=(1, 1)] p bs bv
    bx:bool[1] = squeeze[dimensions=(0,)] bw
    by:i32[1] = sub o 1:i32[]
    bz:i32[1] = jit[
      name=remainder
      jaxpr={ lambda ; by:i32[1] ca:i32[]. let
          cb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ca
          cc:bool[] = eq cb 0:i32[]
          cd:i32[] = jit[name=_where jaxpr=_where1] cc 1:i32[] cb
          ce:i32[1] = rem by cd
          cf:bool[1] = ne ce 0:i32[]
          cg:bool[1] = lt ce 0:i32[]
          ch:bool[] = lt cd 0:i32[]
          ci:bool[1] = ne cg ch
          cj:bool[1] = and ci cf
          ck:i32[1] = add ce cd
          bz:i32[1] = select_n cj ce ck
        in (bz,) }
    ] by 1:i32[]
    cl:i32[1] = stop_gradient bz
    cm:i32[] = squeeze[dimensions=(0,)] cl
    cn:bool[] = lt cm 0:i32[]
    co:i32[] = add cm 1:i32[]
    cp:i32[] = select_n cn cm co
    cq:i32[1] = broadcast_in_dim cp
    cr:bool[1,1] = scatter[
      dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
      indices_are_sorted=True
      mode=GatherScatterMode.FILL_OR_DROP
      unique_indices=True
      update_consts=()
      update_jaxpr=None
    ] p cq q
    cs:f32[1] = jit[
      name=brainevent.binary_csrmv
      jaxpr={ lambda ; ct:f32[] k:i32[1] l:i32[2] bx:bool[1]. let
          cu:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ct
          cv:f32[1] = broadcast_in_dim cu
          cs:f32[1] = binary_csrmv[
            backend=None
            indices_info=ShapeDtypeStruct(shape=(1,), dtype=int32)
            indptr_info=ShapeDtypeStruct(shape=(2,), dtype=int32)
            outs=(ShapedArray(float32[1]),)
            shape=(1, 1)
            transpose=True
            vector_info=ShapeDtypeStruct(shape=(1,), dtype=bool)
            weight_info=ShapeDtypeStruct(shape=(1,), dtype=float32)
          ] cv k l bx
        in (cs,) }
    ] 1.0:f32[] k l bx
    cw:f32[1] = add r cs
    cx:bool[1] = broadcast_in_dim False:bool[]
    cy:i32[] cz:bool[1] = while[
      body_jaxpr={ lambda ; da:i32[1] db:i32[] dc:bool[1]. let
          dd:bool[] = lt db 0:i32[]
          de:i32[] = convert_element_type[new_dtype=int32 weak_type=False] db
          df:i32[] = add de 1:i32[]
          dg:i32[] = select_n dd db df
          dh:i32[1] = dynamic_slice[slice_sizes=(1,)] da dg
          di:i32[] = squeeze[dimensions=(0,)] dh
          dj:bool[] = lt di 0:i32[]
          dk:i32[] = add di 1:i32[]
          dl:i32[] = select_n dj di dk
          dm:i32[1] = broadcast_in_dim dl
          dn:bool[1] = scatter[
            dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
            indices_are_sorted=True
            mode=GatherScatterMode.FILL_OR_DROP
            unique_indices=True
            update_consts=()
            update_jaxpr=None
          ] dc dm True:bool[]
          do:i32[] = add db 1:i32[]
        in (do, dn) }
      body_nconsts=1
      cond_jaxpr={ lambda ; dp:f32[1] dq:f32[] dr:i32[] ds:bool[1]. let
          dt:bool[] = lt dr 1:i32[]
          du:bool[] = lt dr 0:i32[]
          dv:i32[] = convert_element_type[new_dtype=int32 weak_type=False] dr
          dw:i32[] = add dv 1:i32[]
          dx:i32[] = select_n du dr dw
          dy:f32[1] = dynamic_slice[slice_sizes=(1,)] dp dx
          dz:f32[] = squeeze[dimensions=(0,)] dy
          ea:f32[] = convert_element_type[new_dtype=float32 weak_type=False] dq
          eb:bool[] = ge ea dz
          ec:bool[] = convert_element_type[new_dtype=bool weak_type=False] dt
          ed:bool[] = and ec eb
        in (ed,) }
      cond_nconsts=2
    ] m ba n s cx
    ee:f32[1] = neg t
    ef:f32[1] = add ee 0.0:f32[]
    eg:f32[1] = mul 1.0:f32[] cw
    eh:f32[1] = add ef eg
    ei:f32[1] = div eh 10.0:f32[]
    ej:f32[1] = broadcast_in_dim 1.0:f32[]
    ek:f32[1] = div ej 10.0:f32[]
    el:f32[1] = neg ek
    em:f32[1] = mul 0.10000000149011612:f32[] el
    en:f32[1] = abs em
    eo:bool[1] = le en 9.999999747378752e-06:f32[]
    ep:f32[1] = div em 2.0:f32[]
    eq:f32[1] = add 1.0:f32[] ep
    er:f32[1] = mul em em
    es:f32[1] = div er 6.0:f32[]
    et:f32[1] = add eq es
    eu:f32[1] = exp em
    ev:f32[1] = sub eu 1.0:f32[]
    ew:f32[1] = div ev em
    ex:f32[1] = select_n eo ew et
    ey:f32[1] = mul 0.10000000149011612:f32[] ex
    ez:f32[1] = mul ey ei
    fa:f32[1] = add t ez
    fb:f32[1] = add fa 0.0:f32[]
    fc:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ba
    fd:f32[1] = sub fc u
    fe:bool[1] = le fd 0.0:f32[]
    ff:f32[1] = jit[
      name=_where
      jaxpr={ lambda ; fe:bool[1] t:f32[1] fb:f32[1]. let
          ff:f32[1] = select_n fe fb t
        in (ff,) }
    ] fe t fb
    fg:bool[1] = ge ff 20.0:f32[]
    fh:f32[1] = jit[name=_where jaxpr=_where] fg -5.0:f32[] ff
    fi:f32[1] = jit[name=_where jaxpr=_where] fg ba u
    fj:f32[1] = broadcast_in_dim 0.0:f32[]
    fk:i32[] = no_vmap w
    fl:bool[] = eq fk 0:i32[]
    fm:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fl
    cond[
      branches=(
        { lambda ; . let  in () }
        { lambda ; . let
            debug_callback[
              callback=<function debug_callback.<locals>._flat_callback at 0x00000287F4CFBB00>
              effect=OrderedDebug
              partitioned=False
            ]
          in () }
      )
    ] fm
    fn:i32[] = jit[
      name=remainder
      jaxpr={ lambda ; fk:i32[] fo:i32[]. let
          fp:bool[] = eq fo 0:i32[]
          fq:i32[] = jit[
            name=_where
            jaxpr={ lambda ; fp:bool[] fr:i32[] fo:i32[]. let
                fq:i32[] = select_n fp fo fr
              in (fq,) }
          ] fp 1:i32[] fo
          fs:i32[] = rem fk fq
          ft:bool[] = ne fs 0:i32[]
          fu:bool[] = lt fs 0:i32[]
          fv:bool[] = lt fq 0:i32[]
          fw:bool[] = ne fu fv
          fx:bool[] = and fw ft
          fy:i32[] = add fs fq
          fn:i32[] = select_n fx fs fy
        in (fn,) }
    ] fk 1:i32[]
    fz:bool[] = eq fn 0:i32[]
    ga:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fz
    cond[
      branches=(
        { lambda ; . let  in () }
        { lambda ; . let
            debug_callback[
              callback=<function debug_callback.<locals>._flat_callback at 0x00000287F4CFBCE0>
              effect=OrderedDebug
              partitioned=False
            ]
          in () }
      )
    ] ga
    gb:bool[] = eq fk 9:i32[]
    gc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] gb
    cond[
      branches=(
        { lambda ; . let  in () }
        { lambda ; . let
            debug_callback[
              callback=<function debug_callback.<locals>._flat_callback at 0x00000287F5D740E0>
              effect=OrderedDebug
              partitioned=False
            ]
          in () }
      )
    ] gc
    gd:i32[] = add w 1:i32[]
  in (cl, cr, cz, fj, cy, fh, fi, fg, gd) }, (<jax._src.debugging.OrderedDebugEffect object at 0x00000287CCB08D10>,))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_csr\binary.py", line 158, in binary_csrmv
    res = binary_csrmv_p_call(
  File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_csr\binary.py", line 986, in binary_csrmv_p_call
    return binary_csrmv_p(
  File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_op\main.py", line 279, in __call__
    r = self.primitive.bind(*ins, **kwargs, outs=tuple(outs))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ModuleNotFoundError: No module named 'numba'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "D:\CodeWork\CodeTest\course_learning\exp_psy\test\test.py", line 38, in <module>
    main()
  File "D:\CodeWork\CodeTest\course_learning\exp_psy\test\test.py", line 34, in main
    runner.predict(1.0)
  File "D:\miniconda3\envs\test\Lib\site-packages\brainpy\runners.py", line 491, in predict
    outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\miniconda3\envs\test\Lib\site-packages\brainpy\runners.py", line 541, in _predict
    outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\miniconda3\envs\test\Lib\site-packages\brainpy\runners.py", line 666, in _fun_predict
    return bm.for_loop(functools.partial(self._step_func_predict, shared_args=shared_args),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\miniconda3\envs\test\Lib\site-packages\brainpy\math\object_transform\controls.py", line 399, in for_loop
    return brainstate.transform.for_loop(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\miniconda3\envs\test\Lib\site-packages\brainstate\transform\_loop_collect_return.py", line 542, in for_loop
    _, ys = scan(
            ^^^^^
  File "D:\miniconda3\envs\test\Lib\site-packages\brainstate\transform\_loop_collect_return.py", line 263, in scan
    ) = jax.lax.scan(
        ^^^^^^^^^^^^^
  File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_op\main.py", line 418, in fallback_kernel_fn
    kernel = entry.kernel_generator(**kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_csr\binary.py", line 297, in _csrmv_numba_kernel
    import numba
ModuleNotFoundError: No module named 'numba'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Environment

Please provide the following information:

  • BrainPy Version: 2.7.7
  • Python Version: 3.12.13
  • JAX Version: 0.9.1
  • Operating System: Windows 11
  • Hardware: CPU only
  • Installation method: pip

Checklist

  • I have checked for duplicate issues
  • I have provided a minimal code example to reproduce the bug
  • I have included the full error message/traceback
  • I have provided environment information

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions