Skip to content

Commit

Permalink
select : Support Python boolean argument types
Browse files Browse the repository at this point in the history
  • Loading branch information
rtabbara committed Nov 29, 2024
1 parent 4803a9d commit d0c8811
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/python/base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 883,7 @@ nb::object select(nb::handle h0, nb::handle h1, nb::handle h2) {
nb::handle tp = h1.type();

if (is_drjit_type(tp) || !tp.is(h2.type()) ||
tp.is(&PyLong_Type) || tp.is(&PyFloat_Type)) {
tp.is(&PyLong_Type) || tp.is(&PyFloat_Type) || tp.is(&PyBool_Type)) {
PyObject *o = apply<Select>(ArrayOp::Select, "select",
std::make_index_sequence<3>(), h0.ptr(), h1.ptr(), h2.ptr());

Expand Down
4 changes: 4 additions & 0 deletions tests/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 157,8 @@ def test06_select():
assert isinstance(result, s.ArrayXi) and dr.all(result == s.ArrayXi(1, 2))
result = dr.select(s.ArrayXb(True, False), 1, 2.0)
assert isinstance(result, s.ArrayXf) and dr.all(result == s.ArrayXf(1, 2))
result = dr.select(s.ArrayXb(True, False, True), True, False)
assert isinstance(result, s.ArrayXb) and dr.all(result == s.ArrayXb(True, False, True))

result = dr.select(s.Array2b(True, False), s.Array2i(3, 4), s.Array2i(5, 6))
assert isinstance(result, s.Array2i) and dr.all(result == s.Array2i(3, 6))
Expand Down Expand Up @@ -185,6 187,8 @@ def test06_select():
assert isinstance(result, l.ArrayXi) and dr.all(result == l.ArrayXi(1, 2))
result = dr.select(l.ArrayXb(True, False), 1, 2.0)
assert isinstance(result, l.ArrayXf) and dr.all(result == l.ArrayXf(1, 2))
result = dr.select(l.ArrayXb(True, False, True), True, False)
assert isinstance(result, l.ArrayXb) and dr.all(result == l.ArrayXb(True, False, True))

result = dr.select(l.Array2b(True, False), l.Array2i(3, 4), l.Array2i(5, 6))
assert isinstance(result, l.Array2i) and dr.all(result == l.Array2i(3, 6))
Expand Down

0 comments on commit d0c8811

Please sign in to comment.