Skip to content

Commit

Permalink
[xla:ffi] Improved fail messages for custom call tests
Browse files Browse the repository at this point in the history
In case a custom call is not found, the tests now fail gracefully instead of crashing.

PiperOrigin-RevId: 617898636
  • Loading branch information
Adam-Banas authored and tensorflower-gardener committed Mar 21, 2024
1 parent 5e463fc commit 6480a6f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 1751,7 @@ xla_test(
"//xla/service:custom_call_target_registry",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
)
Expand Down
16 changes: 9 additions & 7 deletions third_party/xla/xla/tests/custom_call_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 35,7 @@ limitations under the License.
#include "xla/tests/test_macros.h"
#include "xla/tests/test_utils.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace {
Expand Down Expand Up @@ -117,7 118,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {

module->AddEntryComputation(builder.Build());

Literal result = ExecuteAndTransfer(std::move(module), {});
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}

Expand All @@ -138,7 139,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {

module->AddEntryComputation(builder.Build());

Literal result = ExecuteAndTransfer(std::move(module), {});
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}

Expand All @@ -161,7 162,7 @@ XLA_TEST_F(CustomCallTest, UsedInOtherComputations) {

module->AddEntryComputation(b.Build());

Literal result = ExecuteAndTransfer(std::move(module), {});
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
LiteralTestUtil::ExpectR3EqualArray3D<float>(
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
Expand Down Expand Up @@ -190,7 191,7 @@ XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) {
// Note, the expected result is transposed! This is because the input and
// output layouts of the custom call differ and the called function just
// blindly adds one to each element.
Literal result = ExecuteAndTransfer(std::move(module), {&argument});
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {&argument}));
LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
}

Expand All @@ -217,7 218,7 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) {

Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});

Literal result = ExecuteAndTransfer(std::move(module), {&argument});
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {&argument}));
LiteralTestUtil::ExpectR2Equal<float>({{3.f, 4.f}, {5.f, 6.f}}, result);
}

Expand All @@ -237,7 238,8 @@ XLA_TEST_F(CustomCallTest, TupleOutput) {
Literal arg1 = LiteralUtil::CreateR0<float>(42.f);

Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0});
Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1});
TF_ASSERT_OK_AND_ASSIGN(auto result,
Execute(std::move(module), {&arg0, &arg1}));
EXPECT_EQ(result, expected);
}

Expand All @@ -253,7 255,7 @@ XLA_TEST_F(CustomCallTest, ReportsSuccess) {

module->AddEntryComputation(builder.Build());

Literal result = ExecuteAndTransfer(std::move(module), {});
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}

Expand Down

0 comments on commit 6480a6f

Please sign in to comment.