-
Notifications
You must be signed in to change notification settings - Fork 105
Open
Labels
bug: Something isn't working
Description
Bug Description
pip install "brainpy[cpu]" succeeds without numba, but calling bm.pre2post_event_sum on CPU fails at runtime due to missing numba.
To Reproduce
Minimal code example named as test.py to reproduce the issue:
import importlib.util
import brainpy as bp
import brainpy.math as bm
class Test(bp.dyn.SynConn):
def __init__(self, pre, post, conn):
super().__init__(pre=pre, post=post, conn=conn)
self.pre2post = self.conn.require("pre2post")
self.delay = bm.LengthDelay(self.pre.spike, 0)
def update(self):
delayed_pre_spike = self.delay(0)
self.delay.update(self.pre.spike)
# The line below is the minimal trigger path.
post_sp = bm.pre2post_event_sum(
delayed_pre_spike, self.pre2post, self.post.num, 1.0
)
self.post.input += post_sp
def main():
print("brainpy:", bp.__version__)
print("numba installed:", importlib.util.find_spec("numba") is not None)
pre = bp.neurons.SpikeTimeGroup(1, times=[0.0], indices=[0])
post = bp.neurons.LIF(1)
syn = Test(pre, post, conn=bp.connect.All2All())
net = bp.DynSysGroup(pre=pre, syn=syn, post=post)
# On CPU with environment lacking numba, this should fail here.
runner = bp.DSRunner(net)
runner.predict(1.0)
if __name__ == "__main__":
    main()

Steps to reproduce:
- conda create -n test python=3.12 -c conda-forge
- conda activate test
- pip install "brainpy[cpu]"
- python test.py
Expected Behavior
Either:
- This path runs with only brainpy[cpu], or
- The runtime dependency is declared/documented clearly for this API path.
Actual Behavior
Runtime error on CPU indicating numba is required when executing bm.pre2post_event_sum.
brainpy: 2.7.7
numba installed: False
Traceback (most recent call last):
File "D:\miniconda3\envs\test\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2519, in _lower_jaxpr_to_fun_cached
func_op, _, _ = ctx.cached_primitive_lowerings[key]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: ('closed_call', let _where = { lambda ; a:bool[1] b:f32[] c:f32[1]. let
d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
e:f32[1] = broadcast_in_dim d
f:f32[1] = select_n a c e
in (f,) } in
let _where1 = { lambda ; g:bool[] h:i32[] i:i32[]. let
j:i32[] = select_n g i h
in (j,) } in
{ lambda ; k:i32[1] l:i32[2] m:f32[1] n:i32[1] o:i32[1] p:bool[1,1] q:bool[1] r:f32[1]
s:i32[] t:f32[1] u:f32[1] v:bool[1] w:i32[] x:i32[]. let
y:f32[] = convert_element_type[new_dtype=float32 weak_type=True] x
z:f32[] = mul y 0.1:f32[]
ba:f32[] = add 0.0:f32[] z
bb:i32[] = squeeze[dimensions=(0,)] o
bc:i32[] = add bb 0:i32[]
bd:i32[] = jit[
name=remainder
jaxpr={ lambda ; bc:i32[] be:i32[]. let
bf:i32[] = convert_element_type[new_dtype=int32 weak_type=False] be
bg:bool[] = eq bf 0:i32[]
bh:i32[] = jit[name=_where jaxpr=_where1] bg 1:i32[] bf
bi:i32[] = rem bc bh
bj:bool[] = ne bi 0:i32[]
bk:bool[] = lt bi 0:i32[]
bl:bool[] = lt bh 0:i32[]
bm:bool[] = ne bk bl
bn:bool[] = and bm bj
bo:i32[] = add bi bh
bd:i32[] = select_n bn bi bo
in (bd,) }
] bc 1:i32[]
bp:i32[] = stop_gradient bd
bq:bool[] = lt bp 0:i32[]
br:i32[] = add bp 1:i32[]
bs:i32[] = select_n bq bp br
bt:bool[] = lt 0:i32[] 0:i32[]
bu:i32[] = add 0:i32[] 1:i32[]
bv:i32[] = select_n bt 0:i32[] bu
bw:bool[1,1] = dynamic_slice[slice_sizes=(1, 1)] p bs bv
bx:bool[1] = squeeze[dimensions=(0,)] bw
by:i32[1] = sub o 1:i32[]
bz:i32[1] = jit[
name=remainder
jaxpr={ lambda ; by:i32[1] ca:i32[]. let
cb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ca
cc:bool[] = eq cb 0:i32[]
cd:i32[] = jit[name=_where jaxpr=_where1] cc 1:i32[] cb
ce:i32[1] = rem by cd
cf:bool[1] = ne ce 0:i32[]
cg:bool[1] = lt ce 0:i32[]
ch:bool[] = lt cd 0:i32[]
ci:bool[1] = ne cg ch
cj:bool[1] = and ci cf
ck:i32[1] = add ce cd
bz:i32[1] = select_n cj ce ck
in (bz,) }
] by 1:i32[]
cl:i32[1] = stop_gradient bz
cm:i32[] = squeeze[dimensions=(0,)] cl
cn:bool[] = lt cm 0:i32[]
co:i32[] = add cm 1:i32[]
cp:i32[] = select_n cn cm co
cq:i32[1] = broadcast_in_dim cp
cr:bool[1,1] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
indices_are_sorted=True
mode=GatherScatterMode.FILL_OR_DROP
unique_indices=True
update_consts=()
update_jaxpr=None
] p cq q
cs:f32[1] = jit[
name=brainevent.binary_csrmv
jaxpr={ lambda ; ct:f32[] k:i32[1] l:i32[2] bx:bool[1]. let
cu:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ct
cv:f32[1] = broadcast_in_dim cu
cs:f32[1] = binary_csrmv[
backend=None
indices_info=ShapeDtypeStruct(shape=(1,), dtype=int32)
indptr_info=ShapeDtypeStruct(shape=(2,), dtype=int32)
outs=(ShapedArray(float32[1]),)
shape=(1, 1)
transpose=True
vector_info=ShapeDtypeStruct(shape=(1,), dtype=bool)
weight_info=ShapeDtypeStruct(shape=(1,), dtype=float32)
] cv k l bx
in (cs,) }
] 1.0:f32[] k l bx
cw:f32[1] = add r cs
cx:bool[1] = broadcast_in_dim False:bool[]
cy:i32[] cz:bool[1] = while[
body_jaxpr={ lambda ; da:i32[1] db:i32[] dc:bool[1]. let
dd:bool[] = lt db 0:i32[]
de:i32[] = convert_element_type[new_dtype=int32 weak_type=False] db
df:i32[] = add de 1:i32[]
dg:i32[] = select_n dd db df
dh:i32[1] = dynamic_slice[slice_sizes=(1,)] da dg
di:i32[] = squeeze[dimensions=(0,)] dh
dj:bool[] = lt di 0:i32[]
dk:i32[] = add di 1:i32[]
dl:i32[] = select_n dj di dk
dm:i32[1] = broadcast_in_dim dl
dn:bool[1] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
indices_are_sorted=True
mode=GatherScatterMode.FILL_OR_DROP
unique_indices=True
update_consts=()
update_jaxpr=None
] dc dm True:bool[]
do:i32[] = add db 1:i32[]
in (do, dn) }
body_nconsts=1
cond_jaxpr={ lambda ; dp:f32[1] dq:f32[] dr:i32[] ds:bool[1]. let
dt:bool[] = lt dr 1:i32[]
du:bool[] = lt dr 0:i32[]
dv:i32[] = convert_element_type[new_dtype=int32 weak_type=False] dr
dw:i32[] = add dv 1:i32[]
dx:i32[] = select_n du dr dw
dy:f32[1] = dynamic_slice[slice_sizes=(1,)] dp dx
dz:f32[] = squeeze[dimensions=(0,)] dy
ea:f32[] = convert_element_type[new_dtype=float32 weak_type=False] dq
eb:bool[] = ge ea dz
ec:bool[] = convert_element_type[new_dtype=bool weak_type=False] dt
ed:bool[] = and ec eb
in (ed,) }
cond_nconsts=2
] m ba n s cx
ee:f32[1] = neg t
ef:f32[1] = add ee 0.0:f32[]
eg:f32[1] = mul 1.0:f32[] cw
eh:f32[1] = add ef eg
ei:f32[1] = div eh 10.0:f32[]
ej:f32[1] = broadcast_in_dim 1.0:f32[]
ek:f32[1] = div ej 10.0:f32[]
el:f32[1] = neg ek
em:f32[1] = mul 0.10000000149011612:f32[] el
en:f32[1] = abs em
eo:bool[1] = le en 9.999999747378752e-06:f32[]
ep:f32[1] = div em 2.0:f32[]
eq:f32[1] = add 1.0:f32[] ep
er:f32[1] = mul em em
es:f32[1] = div er 6.0:f32[]
et:f32[1] = add eq es
eu:f32[1] = exp em
ev:f32[1] = sub eu 1.0:f32[]
ew:f32[1] = div ev em
ex:f32[1] = select_n eo ew et
ey:f32[1] = mul 0.10000000149011612:f32[] ex
ez:f32[1] = mul ey ei
fa:f32[1] = add t ez
fb:f32[1] = add fa 0.0:f32[]
fc:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ba
fd:f32[1] = sub fc u
fe:bool[1] = le fd 0.0:f32[]
ff:f32[1] = jit[
name=_where
jaxpr={ lambda ; fe:bool[1] t:f32[1] fb:f32[1]. let
ff:f32[1] = select_n fe fb t
in (ff,) }
] fe t fb
fg:bool[1] = ge ff 20.0:f32[]
fh:f32[1] = jit[name=_where jaxpr=_where] fg -5.0:f32[] ff
fi:f32[1] = jit[name=_where jaxpr=_where] fg ba u
fj:f32[1] = broadcast_in_dim 0.0:f32[]
fk:i32[] = no_vmap w
fl:bool[] = eq fk 0:i32[]
fm:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fl
cond[
branches=(
{ lambda ; . let in () }
{ lambda ; . let
debug_callback[
callback=<function debug_callback.<locals>._flat_callback at 0x00000287F4CFBB00>
effect=OrderedDebug
partitioned=False
]
in () }
)
] fm
fn:i32[] = jit[
name=remainder
jaxpr={ lambda ; fk:i32[] fo:i32[]. let
fp:bool[] = eq fo 0:i32[]
fq:i32[] = jit[
name=_where
jaxpr={ lambda ; fp:bool[] fr:i32[] fo:i32[]. let
fq:i32[] = select_n fp fo fr
in (fq,) }
] fp 1:i32[] fo
fs:i32[] = rem fk fq
ft:bool[] = ne fs 0:i32[]
fu:bool[] = lt fs 0:i32[]
fv:bool[] = lt fq 0:i32[]
fw:bool[] = ne fu fv
fx:bool[] = and fw ft
fy:i32[] = add fs fq
fn:i32[] = select_n fx fs fy
in (fn,) }
] fk 1:i32[]
fz:bool[] = eq fn 0:i32[]
ga:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fz
cond[
branches=(
{ lambda ; . let in () }
{ lambda ; . let
debug_callback[
callback=<function debug_callback.<locals>._flat_callback at 0x00000287F4CFBCE0>
effect=OrderedDebug
partitioned=False
]
in () }
)
] ga
gb:bool[] = eq fk 9:i32[]
gc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] gb
cond[
branches=(
{ lambda ; . let in () }
{ lambda ; . let
debug_callback[
callback=<function debug_callback.<locals>._flat_callback at 0x00000287F5D740E0>
effect=OrderedDebug
partitioned=False
]
in () }
)
] gc
gd:i32[] = add w 1:i32[]
in (cl, cr, cz, fj, cy, fh, fi, fg, gd) }, (<jax._src.debugging.OrderedDebugEffect object at 0x00000287CCB08D10>,))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_csr\binary.py", line 158, in binary_csrmv
res = binary_csrmv_p_call(
File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_csr\binary.py", line 986, in binary_csrmv_p_call
return binary_csrmv_p(
File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_op\main.py", line 279, in __call__
r = self.primitive.bind(*ins, **kwargs, outs=tuple(outs))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ModuleNotFoundError: No module named 'numba'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "D:\CodeWork\CodeTest\course_learning\exp_psy\test\test.py", line 38, in <module>
main()
File "D:\CodeWork\CodeTest\course_learning\exp_psy\test\test.py", line 34, in main
runner.predict(1.0)
File "D:\miniconda3\envs\test\Lib\site-packages\brainpy\runners.py", line 491, in predict
outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\miniconda3\envs\test\Lib\site-packages\brainpy\runners.py", line 541, in _predict
outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\miniconda3\envs\test\Lib\site-packages\brainpy\runners.py", line 666, in _fun_predict
return bm.for_loop(functools.partial(self._step_func_predict, shared_args=shared_args),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\miniconda3\envs\test\Lib\site-packages\brainpy\math\object_transform\controls.py", line 399, in for_loop
return brainstate.transform.for_loop(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\miniconda3\envs\test\Lib\site-packages\brainstate\transform\_loop_collect_return.py", line 542, in for_loop
_, ys = scan(
^^^^^
File "D:\miniconda3\envs\test\Lib\site-packages\brainstate\transform\_loop_collect_return.py", line 263, in scan
) = jax.lax.scan(
^^^^^^^^^^^^^
File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_op\main.py", line 418, in fallback_kernel_fn
kernel = entry.kernel_generator(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\miniconda3\envs\test\Lib\site-packages\brainevent\_csr\binary.py", line 297, in _csrmv_numba_kernel
import numba
ModuleNotFoundError: No module named 'numba'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Environment
Please provide the following information:
- BrainPy Version: 2.7.7
- Python Version: 3.12.13
- JAX Version: 0.9.1
- Operating System: Windows 11
- Hardware: CPU only
- Installation method: pip
Checklist
- I have checked for duplicate issues
- I have provided a minimal code example to reproduce the bug
- I have included the full error message/traceback
- I have provided environment information
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug: Something isn't working