diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-24 09:08:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-24 09:13:11 -0700 |
commit | 0742b4b1eacdcca1473f856fbd05a6f501eff4ed (patch) | |
tree | 91cd4bbf6f7bb7d6616956cfd87a4206b41183b3 /tensorflow | |
parent | 133472fb49c51a7a47d420c282fcc55a0cc1cdb7 (diff) |
[XLA] StatusOr testing macro TF_ASSIGN_OR_ASSERT_OK => TF_ASSERT_OK_AND_ASSIGN
PiperOrigin-RevId: 162944460
Diffstat (limited to 'tensorflow')
18 files changed, 140 insertions, 146 deletions
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 6c3648e1e0..b50e741b8a 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -835,8 +835,8 @@ TEST_F(LiteralUtilTest, ConvertR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<Literal> converted, - original->Convert(U32)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted, + original->Convert(U32)); EXPECT_TRUE(expected->Equal(*converted)); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index e27575db43..18acd4f3ae 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1573,7 +1573,7 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); } @@ -1652,7 +1652,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { { FlattenCallGraph flatten; - TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } @@ -1723,7 +1723,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { { FlattenCallGraph flatten; - TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index a08506d84d..12a6794ac1 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -139,7 +139,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get()); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); @@ -182,7 +182,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); @@ -211,7 +211,7 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); EXPECT_EQ(7, module->computations().size()); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 1c60b06ddd..3ae499d5e0 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -51,7 +51,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -72,7 +72,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -93,7 +93,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -133,7 +133,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -148,9 +148,9 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_start[] = {4, 2, 3, 1, 5}; const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; - TF_ASSIGN_OR_ASSERT_OK(auto literal, - LiteralTestUtil::CreateRandomLiteral<F32>( - ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto literal, + LiteralTestUtil::CreateRandomLiteral<F32>( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); @@ -160,7 +160,7 @@ TEST_F(HloConstantFoldingTest, Slice) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -171,9 +171,9 @@ TEST_F(HloConstantFoldingTest, Slice) { TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; - TF_ASSIGN_OR_ASSERT_OK(auto literal, - LiteralTestUtil::CreateRandomLiteral<F32>( - ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto literal, + LiteralTestUtil::CreateRandomLiteral<F32>( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -185,7 +185,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 626bd3b02b..7269fbeffc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -218,9 +218,9 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { TEST_F(HloEvaluatorTest, DoesReshape) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; - TF_ASSIGN_OR_ASSERT_OK(auto literal, - LiteralTestUtil::CreateRandomLiteral<F32>( - ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto literal, + LiteralTestUtil::CreateRandomLiteral<F32>( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index a1e38803c4..ad6070a9c1 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -62,7 +62,7 @@ TEST_F(HloOrderingTest, LastUseScheduledFirst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 3a935dcf96..2358969f38 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -158,7 +158,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -191,7 +191,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -232,7 +232,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -268,7 +268,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(body_computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -310,7 +310,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -406,7 +406,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, @@ -503,7 +503,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 0d50810dc4..07739f241a 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -51,15 +51,15 @@ TEST_F(UserComputationTest, SimpleComputation) { ConstantRequest constant_request; *constant_request.mutable_literal() = Literal::CreateR1<float>({123.0f, 42.0f})->ToProto(); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle, - computation.AddConstantInstruction(constant_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, + computation.AddConstantInstruction(constant_request)); ParameterRequest param_request; *param_request.mutable_shape() = kScalarShape; param_request.set_parameter(0); param_request.set_name("param0"); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle param_handle, - computation.AddParameterInstruction(param_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, + computation.AddParameterInstruction(param_request)); OpMetadata metadata; metadata.set_op_name("meta"); TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); @@ -81,7 +81,7 @@ TEST_F(UserComputationTest, SimpleComputation) { // Program shape should have a single scalar parameter and scalar // result. The outfeed instruction should not affect the program shape. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::shared_ptr<const ProgramShape> program_shape, computation.ComputeProgramShape(latest_version.version)); ASSERT_EQ(1, program_shape->parameters_size()); @@ -90,7 +90,7 @@ TEST_F(UserComputationTest, SimpleComputation) { EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); // Build the HLO computation. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<HloComputation> hlo_computation, computation.BuildHloComputation(latest_version.version, hlo_resolver, DebugOptions())); @@ -108,7 +108,7 @@ TEST_F(UserComputationTest, SimpleComputation) { computation.GetVersionedHandleAtOperation(param_handle); // Program shape should have a single scalar parameter, and scalar result. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::shared_ptr<const ProgramShape> program_shape, computation.ComputeProgramShape(version_at_param.version)); ASSERT_EQ(1, program_shape->parameters_size()); @@ -118,7 +118,7 @@ TEST_F(UserComputationTest, SimpleComputation) { // There should be two instructions, one for the constant and one for the // parameter. The outfeed instruction should not be included. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<HloComputation> hlo_computation, computation.BuildHloComputation(version_at_param.version, hlo_resolver, DebugOptions())); @@ -132,7 +132,7 @@ TEST_F(UserComputationTest, SimpleComputation) { computation.GetVersionedHandle(); // Build the HLO computation. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<HloComputation> hlo_computation, computation.BuildHloComputation( latest_version.version, hlo_resolver, DebugOptions(), @@ -165,13 +165,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { ConstantRequest a_request; *a_request.mutable_literal() = Literal::CreateR1<float>({123.0f, 42.0f})->ToProto(); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, - computation.AddConstantInstruction(a_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, + computation.AddConstantInstruction(a_request)); ConstantRequest b_request; *b_request.mutable_literal() = Literal::CreateR0<float>(1.0f)->ToProto(); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, - computation.AddConstantInstruction(b_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, + computation.AddConstantInstruction(b_request)); BinaryOpRequest add; add.set_binop(BINOP_ADD); @@ -185,7 +185,7 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { VersionedComputationHandle latest_version = computation.GetVersionedHandle(); // Build the HLO computation. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<HloComputation> hlo_computation, computation.BuildHloComputation(latest_version.version, hlo_resolver, DebugOptions())); @@ -218,15 +218,15 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); a_request.set_name("a"); a_request.set_parameter(0); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, + computation.AddParameterInstruction(a_request)); ParameterRequest b_request; *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); b_request.set_name("b"); b_request.set_parameter(1); - TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, + computation.AddParameterInstruction(b_request)); BinaryOpRequest add; add.set_binop(BINOP_ADD); @@ -242,7 +242,7 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { VersionedComputationHandle latest_version = computation.GetVersionedHandle(); // Build the HLO computation. - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<HloComputation> hlo_computation, computation.BuildHloComputation(latest_version.version, hlo_resolver, DebugOptions())); diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h index aa12cda666..d9ca8f320b 100644 --- a/tensorflow/compiler/xla/status_macros.h +++ b/tensorflow/compiler/xla/status_macros.h @@ -183,12 +183,12 @@ class StatusAdaptorForMacros { .with_log_stack_trace() \ .add_ret_check_failure(#condition) -#define TF_ASSIGN_OR_ASSERT_OK(lhs, rexpr) \ - TF_ASSIGN_OR_ASSERT_OK_IMPL( \ +#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ + TF_ASSERT_OK_AND_ASSIGN_IMPL( \ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ rexpr); -#define TF_ASSIGN_OR_ASSERT_OK_IMPL(statusor, lhs, rexpr) \ +#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ auto statusor = (rexpr); \ ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ lhs = statusor.ConsumeValueOrDie() diff --git a/tensorflow/compiler/xla/status_macros_test.cc b/tensorflow/compiler/xla/status_macros_test.cc index dead17cdfa..4b0740dad7 100644 --- a/tensorflow/compiler/xla/status_macros_test.cc +++ b/tensorflow/compiler/xla/status_macros_test.cc @@ -63,7 +63,7 @@ StatusOr<int> CreateIntUnsuccessfully() { } TEST(StatusMacros, AssignOrAssertOnOK) { - TF_ASSIGN_OR_ASSERT_OK(int result, CreateIntSuccessfully()); + TF_ASSERT_OK_AND_ASSIGN(int result, CreateIntSuccessfully()); EXPECT_EQ(42, result); } diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 3d6a995a24..c65f8c0f08 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -67,9 +67,9 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { // Try copying the elements back and comparing it auto handles = result_status.ConsumeValueOrDie(); std::unique_ptr<Literal> literal; - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal); } @@ -89,17 +89,17 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { auto handles2 = result_status2.ConsumeValueOrDie(); std::unique_ptr<Literal> literal; - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles1[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0])); LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles1[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1])); LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal); handles1[0].reset(); handles1[1].reset(); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles2[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0])); LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles2[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1])); LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal); } @@ -119,13 +119,13 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { auto handles = result_status.ConsumeValueOrDie(); std::unique_ptr<Literal> literal; - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[2])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[3])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3])); LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal); } @@ -145,17 +145,17 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { global_data.reset(); std::unique_ptr<Literal> literal; - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[2])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); - TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[2])); + TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal); } diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index f54fa2256e..eded2077fc 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -46,7 +46,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) { builder.ClearOpMetadata(); Shape argument_layout = ShapeUtil::MakeShape(F32, {}); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<LocalExecutable> executable, local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout, &argument_layout}, diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 717e9cd494..9ad9b33176 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -178,11 +178,11 @@ TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { Shape rhs_shape = ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<float>( lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<float>( rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index ed994fda45..0a2d4c763d 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -68,13 +68,12 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims) { auto shape = ShapeUtil::MakeShape(U32, dims); builder.RngBernoulli(builder.ConstantR0<float>(p), shape); - TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); ExecutionOptions execution_options = execution_options_; execution_options.set_seed(42); - TF_ASSIGN_OR_ASSERT_OK( - auto actual, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options)); + TF_ASSERT_OK_AND_ASSIGN( + auto actual, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options)); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); int32 sum = 0; actual->EachCell<uint32>( @@ -167,22 +166,21 @@ XLA_TEST_F(PrngTest, MapUsingRng) { ComputationBuilder builder(client_, TestName()); std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f}); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> param0_data, - client_->TransferToServer(*param0_literal)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data, + client_->TransferToServer(*param0_literal)); auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto fn = build_sum_rng(builder); builder.Map({param0}, fn); - TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); ExecutionOptions execution_options = execution_options_; execution_options.set_seed(125); - TF_ASSIGN_OR_ASSERT_OK( - auto actual, - client_->ExecuteAndTransfer(computation, - /*arguments=*/{param0_data.get()}, - &execution_options)); + TF_ASSERT_OK_AND_ASSIGN( + auto actual, client_->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get()}, &execution_options)); EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size()); for (int i = 0; i < param0_literal->f32s_size(); ++i) { @@ -213,39 +211,35 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { std::unique_ptr<Literal> result1; { - TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); - TF_ASSIGN_OR_ASSERT_OK( - result1, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options1)); + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options1)); } std::unique_ptr<Literal> result2; std::unique_ptr<Literal> result3; { - TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); - TF_ASSIGN_OR_ASSERT_OK( - result2, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options1)); - TF_ASSIGN_OR_ASSERT_OK( - result3, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options1)); + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result2, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options1)); + TF_ASSERT_OK_AND_ASSIGN( + result3, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options1)); } std::unique_ptr<Literal> result4; std::unique_ptr<Literal> result5; std::unique_ptr<Literal> result6; { - TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation()); - TF_ASSIGN_OR_ASSERT_OK( - result4, - client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options2)); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result4, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options2)); + TF_ASSERT_OK_AND_ASSIGN( result5, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options_)); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( result6, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options_)); } diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index a438236c45..9774e40941 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -314,8 +314,8 @@ TEST_F(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> input_data, - client_->TransferToServer(*input_literal)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data, + client_->TransferToServer(*input_literal)); ComputeAndCompareR4<float>(&builder_, *res, {input_data.get()}, ErrorSpec(1e-3, 1e-3)); } @@ -337,8 +337,8 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> input_data, - client_->TransferToServer(*input_literal)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data, + client_->TransferToServer(*input_literal)); ComputeAndCompareR4<float>(&builder_, *res, {input_data.get()}, ErrorSpec(1e-3, 1e-3)); } @@ -360,8 +360,8 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> input_data, - client_->TransferToServer(*input_literal)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data, + client_->TransferToServer(*input_literal)); ComputeAndCompareR4<float>(&builder_, *res, {input_data.get()}, ErrorSpec(1e-3, 1e-3)); } @@ -383,8 +383,8 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> input_data, - client_->TransferToServer(*input_literal)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data, + client_->TransferToServer(*input_literal)); ComputeAndCompareR4<float>(&builder_, *res, {input_data.get()}, ErrorSpec(1e-3, 1e-3)); } @@ -507,8 +507,8 @@ class R4ReduceWindowTest input.FillIota(1); std::unique_ptr<Literal> input_literal = Literal::CreateR4FromArray4D(input); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> input_arg, - client_->TransferToServer(*input_literal)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_arg, + client_->TransferToServer(*input_literal)); std::vector<std::pair<int64, int64>> padding(4); for (int i = 0; i < 4; ++i) { @@ -774,8 +774,8 @@ TEST_P(R2ReduceWindowTest, Add) { std::unique_ptr<Literal> input_literal = Literal::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> input_arg, - client_->TransferToServer(*input_literal)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_arg, + client_->TransferToServer(*input_literal)); b.ReduceWindow(/*operand=*/ b.Parameter(0, input_literal->shape(), "p0"), /*init_value=*/b.ConstantR0<float>(kInitValue), @@ -878,8 +878,8 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr<Literal> input_literal = Literal::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector)); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> input_arg, - client_->TransferToServer(*input_literal)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_arg, + client_->TransferToServer(*input_literal)); auto computation = param.reducer == kAdd ? CreateScalarAddComputation(F32, &b) diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 07bd00f015..6ebd11584f 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -354,7 +354,7 @@ TEST_F(ScalarComputationsTest, DivU32s) { ComputationDataHandle divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Div(dividend, divisor); - TF_ASSIGN_OR_ASSERT_OK(div_computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build()); } for (uint32 divisor : vals) { @@ -362,10 +362,10 @@ TEST_F(ScalarComputationsTest, DivU32s) { for (uint32 dividend : vals) { auto dividend_literal = Literal::CreateR0<uint32>(dividend); auto divisor_literal = Literal::CreateR0<uint32>(divisor); - TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, - client_->TransferToServer(*dividend_literal)); - TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, + client_->TransferToServer(*dividend_literal)); + TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, + client_->TransferToServer(*divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(div_computation, @@ -395,7 +395,7 @@ TEST_F(ScalarComputationsTest, RemU32s) { ComputationDataHandle divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Rem(dividend, divisor); - TF_ASSIGN_OR_ASSERT_OK(rem_computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build()); } for (uint32 divisor : vals) { @@ -403,10 +403,10 @@ TEST_F(ScalarComputationsTest, RemU32s) { for (uint32 dividend : vals) { auto dividend_literal = Literal::CreateR0<uint32>(dividend); auto divisor_literal = Literal::CreateR0<uint32>(divisor); - TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, - client_->TransferToServer(*dividend_literal)); - TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, + client_->TransferToServer(*dividend_literal)); + TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, + client_->TransferToServer(*divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(rem_computation, @@ -426,7 +426,7 @@ TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { builder.Rem(x, builder.ConstantR0<int32>(80000)); std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(87919); - TF_ASSIGN_OR_ASSERT_OK(auto input_data, client_->TransferToServer(*literal)); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index afa7d871c0..8a6c40a0f5 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -387,7 +387,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); - TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } Computation condition2; @@ -397,7 +397,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); - TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } // Create a computation for the body. @@ -413,7 +413,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); - TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } Computation body2; @@ -426,7 +426,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); - TF_ASSIGN_OR_ASSERT_OK(body2, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. @@ -466,7 +466,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); - TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } Computation condition2; @@ -476,7 +476,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); - TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } // Create a computation for the body. @@ -492,7 +492,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); - TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. @@ -533,7 +533,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); - TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } Computation condition2; @@ -543,7 +543,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); - TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } // Create a computation for the body. @@ -559,7 +559,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); - TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. @@ -697,11 +697,11 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { }; for (int i = 1; i < 4; ++i) { - TF_ASSIGN_OR_ASSERT_OK(auto computation, while_loop(i)); + TF_ASSERT_OK_AND_ASSIGN(auto computation, while_loop(i)); ExecutionOptions execution_options = execution_options_; execution_options.set_seed(65); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); } diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc index 32336a767a..db811bda36 100644 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc +++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc @@ -108,7 +108,7 @@ TEST(XlaTfGraphUtil, ConvertTfGraphToSessionModule) { std::vector<XlaCompiler::Argument> args = BuildAddGraphArguments(); std::unique_ptr<Graph> graph = BuildAddGraph(); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<xla::SessionModule> session_module, ConvertTfGraphToXlaSessionModule(args, std::move(graph))); @@ -122,11 +122,11 @@ TEST(XlaTfGraphUtil, ConvertTfGraphToSessionModule) { TEST(XlaTfGraphUtil, ConvertXlaSessionModuleToXlaNodes) { std::vector<XlaCompiler::Argument> args = BuildAddGraphArguments(); std::unique_ptr<Graph> graph = BuildAddGraph(); - TF_ASSIGN_OR_ASSERT_OK( + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<xla::SessionModule> session_module, ConvertTfGraphToXlaSessionModule(args, std::move(graph))); - TF_ASSIGN_OR_ASSERT_OK(auto xla_nodes, - ConvertXlaSessionModuleToXlaNodes(*session_module)); + TF_ASSERT_OK_AND_ASSIGN(auto xla_nodes, + ConvertXlaSessionModuleToXlaNodes(*session_module)); EXPECT_EQ(session_module->entry().requests_size(), xla_nodes.size()); } |