diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/dynamic_ops_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 506 |
1 file changed, 506 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc new file mode 100644 index 0000000000..cecc4872df --- /dev/null +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -0,0 +1,506 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <numeric> +#include <vector> + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" 
+#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace { + +class DynamicSliceTest : public ClientLibraryTestBase { + protected: + template <typename IndexT> + void TestR1() { + // Slice at dimension start. + RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {0}, {5}, + {0.0, 1.0, 2.0, 3.0, 4.0}); + // Slice in the middle. + RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {3}, + {2.0, 3.0, 4.0}); + // Slice at dimension boundaries. + RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {5}, {3}, + {5.0, 6.0, 7.0}); + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4}, + {6.0, 7.0, 0.0, 1.0}); + } + + template <typename IndexT> + void TestR2() { + // Slice at dimension start. + RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {0, 0}, {2, 2}, {{1.0f, 2.0f}, {4.0f, 5.0f}}); + // Slice in the middle. + RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {1, 1}, {2, 1}, {{5.0f}, {8.0f}}); + // Slice at dimension boundaries. + RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {1, 1}, {2, 1}, {{5.0f}, {8.0f}}); + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {1, 1}, {3, 3}, + {{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}}); + } + + template <typename IndexT> + void TestR3() { + // R3 Shape: [2, 3, 2] + // clang-format off + + // Slice at dimension start. + RunR3<IndexT>( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {0, 0, 0}, {2, 1, 2}, + {{{1.0f, 2.0f}}, {{7.0f, 8.0f}}}); + + // Slice in the middle. 
+ RunR3<IndexT>( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {0, 1, 1}, {2, 2, 1}, + {{{4.0f}, {6.0f}}, {{10.0f}, {12.0f}}}); + + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR3<IndexT>( + {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}, + {0, 2, 1}, {2, 2, 1}, + {{{6.0f}, {2.0f}}, {{12.0f}, {8.0f}}}); + + // clang-format on + } + + template <typename IndexT> + void RunR1(const std::vector<float>& input_values, + const std::vector<IndexT> slice_starts, + const std::vector<int64> slice_sizes, + const std::vector<float>& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR1<float>(input_values); + builder.DynamicSlice(input, starts, slice_sizes); + // Run computation and compare against expected values. + ComputeAndCompareR1<float>(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } + + template <typename IndexT> + void RunR2(const Array2D<float>& input_values, + const std::vector<IndexT> slice_starts, + const std::vector<int64> slice_sizes, + const Array2D<float>& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR2FromArray2D<float>(input_values); + builder.DynamicSlice(input, starts, slice_sizes); + // Run computation and compare against expected values. 
+ ComputeAndCompareR2<float>(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } + + template <typename IndexT> + void RunR3(const Array3D<float>& input_values, + const std::vector<IndexT> slice_starts, + const std::vector<int64> slice_sizes, + const Array3D<float>& expected_values) { + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantR3FromArray3D<float>(input_values); + builder.DynamicSlice(input, starts, slice_sizes); + // Run computation and compare against expected values. + ComputeAndCompareR3<float>(&builder, expected_values, {start_data.get()}, + ErrorSpec(0.000001)); + } +}; + +XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32>(); } + +XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64>(); } + +XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64>(); } + +XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32>(); } + +XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64>(); } + +XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64>(); } + +XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32>(); } + +XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64>(); } + +XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64>(); } + +class DynamicUpdateSliceTest : public ClientLibraryTestBase { + protected: + template <typename IndexT> + void TestR1() { + // clang-format off + // Slice at dimension start. + RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, + {8.0, 9.0, 10.0}, {0}, + {8.0, 9.0, 10.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + // Slice in the middle. + RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, + {8.0, 9.0, 10.0}, {2}, + {0.0, 1.0, 8.0, 9.0, 10.0, 5.0, 6.0, 7.0}); + // Slice at dimension boundaries. 
    RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
                  {8.0, 9.0, 10.0}, {5},
                  {0.0, 1.0, 2.0, 3.0, 4.0, 8.0, 9.0, 10.0});
    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
    // NOTE(review): the expected values here encode what the current backends
    // produce for an out-of-bounds update (last wrapped element dropped);
    // per the b/34128753 TODO below this is implementation-specific.
    RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
                  {8.0, 9.0, 10.0}, {6},
                  {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0});
    // clang-format on
  }

  template <typename IndexT>
  void TestR2() {
    // clang-format off
    // Slice at dimension start.
    RunR2<IndexT>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
      {{10.0f, 11.0f}}, {0, 0},
      {{10.0f, 11.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
    // Slice in the middle.
    RunR2<IndexT>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
      {{10.0f, 11.0f}}, {1, 1},
      {{1.0f, 2.0f, 3.0f}, {4.0f, 10.0f, 11.0f}, {7.0f, 8.0f, 9.0f}});
    // Slice at dimension boundaries.
    RunR2<IndexT>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
      {{10.0f, 11.0f}}, {2, 1},
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 10.0f, 11.0f}});
    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
    RunR2<IndexT>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
      {{10.0f, 11.0f}}, {2, 2},
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}});
    // clang-format on
  }

  template <typename IndexT>
  void TestR3() {
    // R3 Shape: [2, 3, 2]
    // clang-format off
    // Slice at dimension start.
    RunR3<IndexT>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
      {{{13.0f, 14.0f}, {15.0f, 16.0f}},
       {{17.0f, 18.0f}, {19.0f, 20.0f}}},
      {0, 0, 0},
      {{{13.0f, 14.0f}, {15.0f, 16.0f}, {5.0f, 6.0f}},
       {{17.0f, 18.0f}, {19.0f, 20.0f}, {11.0f, 12.0f}}});
    // Slice in the middle.
    RunR3<IndexT>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
      {{{13.0f}, {15.0f}}},
      {1, 1, 1},
      {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
       {{7.0f, 8.0f}, {9.0f, 13.0f}, {11.0f, 15.0f}}});
    // Slice at dimension boundaries, but with sizes that cause indices to wrap.
    RunR3<IndexT>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
      {{{13.0f}, {15.0f}}},
      {1, 2, 1},
      {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 13.0f}}});
    // clang-format on
  }

  // Builds a DynamicUpdateSlice computation over constant R1 input/update
  // values, feeding the start indices as a device parameter, and compares
  // against expectations.
  template <typename IndexT>
  void RunR1(const std::vector<float>& input_values,
             const std::vector<float>& update_values,
             const std::vector<IndexT> slice_starts,
             const std::vector<float>& expected_values) {
    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic slice computation.
    auto input = builder.ConstantR1<float>(input_values);
    auto update = builder.ConstantR1<float>(update_values);
    builder.DynamicUpdateSlice(input, update, starts);
    // Run computation and compare against expected values.
    ComputeAndCompareR1<float>(&builder, expected_values, {start_data.get()},
                               ErrorSpec(0.000001));
  }

  // R2 variant of RunR1 above.
  template <typename IndexT>
  void RunR2(const Array2D<float>& input_values,
             const Array2D<float>& update_values,
             const std::vector<IndexT> slice_starts,
             const Array2D<float>& expected_values) {
    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic slice computation.
    auto input = builder.ConstantR2FromArray2D<float>(input_values);
    auto update = builder.ConstantR2FromArray2D<float>(update_values);
    builder.DynamicUpdateSlice(input, update, starts);
    // Run computation and compare against expected values.
    ComputeAndCompareR2<float>(&builder, expected_values, {start_data.get()},
                               ErrorSpec(0.000001));
  }

  // R3 variant of RunR1 above.
  template <typename IndexT>
  void RunR3(const Array3D<float>& input_values,
             const Array3D<float>& update_values,
             const std::vector<IndexT> slice_starts,
             const Array3D<float>& expected_values) {
    ComputationBuilder builder(client_, TestName());
    // Initialize and transfer dynamic slice start indices parameter.
    ComputationDataHandle starts;
    std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
        slice_starts, 0, "slice_starts", &builder, &starts);
    // Build dynamic slice computation.
    auto input = builder.ConstantR3FromArray3D<float>(input_values);
    auto update = builder.ConstantR3FromArray3D<float>(update_values);
    builder.DynamicUpdateSlice(input, update, starts);
    // Run computation and compare against expected values.
    ComputeAndCompareR3<float>(&builder, expected_values, {start_data.get()},
                               ErrorSpec(0.000001));
  }

  // Runs a contiguous update (only the major `seq` dimension is sliced, at
  // `index`, with `size` planes) against an operand of shape
  // {seq, batch, dim}. Expected values are computed here with wrapping
  // (modulo kSeq) when index + size exceeds the sequence length.
  void RunR3Contiguous(std::vector<int32> operand_shape, int32 index,
                       int32 size) {
    const int32 kSeq = operand_shape[0];
    const int32 kBatch = operand_shape[1];
    const int32 kDim = operand_shape[2];
    Array3D<float> input_values(kSeq, kBatch, kDim);
    Array3D<float> update_values(size, kBatch, kDim);
    Array3D<float> expected_values(kSeq, kBatch, kDim);

    // Input counts up from 0; update values start at 1000 so the updated
    // region is unmistakable in the output.
    input_values.FillIota(0);
    float val = 1000;
    update_values.FillIota(val);

    // TODO(b/34128753) Expected values may vary depending on backend when
    // the update wraps. According to documentation, the results are technically
    // implementation specific where the update is out of bounds, and hence
    // we don't really know what to pass into ComputeAndCompareR3.
    expected_values.FillIota(0);
    for (int i = 0; i < size; i++) {
      for (int j = 0; j < kBatch; j++) {
        for (int k = 0; k < kDim; k++) {
          expected_values((index + i) % kSeq, j, k) = val++;
        }
      }
    }
    if (VLOG_IS_ON(1)) {
      DumpArray<float>("input", input_values);
      DumpArray<float>("update", update_values);
      DumpArray<float>("expected", expected_values);
    }

    // Build dynamic slice computation.
    ComputationBuilder builder(client_, TestName());
    auto starts = builder.ConstantR1<int32>({index, 0, 0});
    auto input = builder.ConstantR3FromArray3D<float>(input_values);
    auto update = builder.ConstantR3FromArray3D<float>(update_values);
    builder.DynamicUpdateSlice(input, update, starts);

    // Run computation and compare against expected values.
    ComputeAndCompareR3<float>(&builder, expected_values, {},
                               ErrorSpec(0.000001));
  }

  // Logs an R3 array as a literal, for debugging with --vmodule/--v=1.
  template <typename NativeT>
  void DumpArray(const string& name, const Array3D<NativeT> values) {
    std::unique_ptr<Literal> literal =
        LiteralUtil::CreateR3FromArray3D<NativeT>(values);
    LOG(INFO) << name << ":" << LiteralUtil::ToString(*literal);
  }
};

XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32>(); }

XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64>(); }

XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64>(); }

XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32>(); }

XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64>(); }

XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64>(); }

XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32>(); }

XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64>(); }

XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64>(); }

// Tests for simple R3 case where the update is contiguous (i.e. the minor
// two dimensions are not sliced).
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
  // Single element, no wrap.
  std::vector<int32> operand_shape({4, 5, 2});
  RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
  // Multiple element, no wrap.
  std::vector<int32> operand_shape({4, 5, 2});
  RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2);
}

// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle
// wrapping as expected.
XLA_TEST_F(DynamicUpdateSliceTest,
           DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousMultipleWrapping))) {
  // Multiple element, wrapping: index 3 + size 2 runs past seq length 4.
  std::vector<int32> operand_shape({4, 5, 2});
  RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2);
}

// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle
// wrapping as expected.
XLA_TEST_F(DynamicUpdateSliceTest,
           DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousTooLarge))) {
  // Multiple element, update size larger than operand.
  std::vector<int32> operand_shape({4, 5, 2});
  RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2);
}

XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
  // Minor dimensions (123, 247) deliberately not multiples of common vector
  // widths, to exercise unaligned copies.
  std::vector<int32> operand_shape({3, 123, 247});
  RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1);
}

// TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) {
  std::vector<int32> operand_shape({32, 128, 1024});
  RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1);
}

// Benchmarks repeated execution of a small DynamicSlice computation on the
// default platform, using the local-client execution path. Setup (client,
// computation build, parameter transfer, warm-up runs) is excluded from the
// timed region via StopTiming/StartTiming.
void BM_DynamicSlice(int num_iters) {
  tensorflow::testing::StopTiming();

  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
  StreamExecutorMemoryAllocator allocator(platform, executors);
  LocalClient* client =
      ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
  auto* transfer_manager =
      TransferManager::GetForPlatform(platform).ValueOrDie();
  int device_ordinal = client->default_device_ordinal();

  ComputationBuilder builder(client, "DynamicSlice");

  // Create input as a constant: shape [1, 2, 3, 4]
  auto input_literal = LiteralUtil::CreateR4(
      {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
        {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
  auto input = builder.ConstantLiteral(*input_literal);

  // Create dynamic slice start indices as a parameter: shape [4]
  auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
  auto start_indices =
      builder.Parameter(0, start_indices_shape, "start_indices");
  // Add DynamicSlice op to the computation.
  builder.DynamicSlice(input, start_indices, {1, 1, 1, 1});
  auto computation = builder.Build().ConsumeValueOrDie();

  // Initialize and transfer parameter buffer: device memory for the start
  // indices, owned by `allocator` via the scoped buffer.
  auto buffer = ScopedShapedBuffer::MakeScopedShapedBuffer(start_indices_shape,
                                                           &allocator, 0)
                    .ConsumeValueOrDie();

  auto start_indices_literal = LiteralUtil::CreateR1<int32>({0, 1, 2, 3});
  ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
      executors[device_ordinal], *start_indices_literal,
      buffer->mutable_buffer({})));

  // Run some warm-up executions.
  LocalExecuteOptions options;
  options.set_allocator(&allocator);
  const int kWarmups = 2;
  for (int i = 0; i < kWarmups; ++i) {
    auto result = client->ExecuteLocally(computation, {buffer.get()}, options);
    ASSERT_TRUE(result.ok());
  }

  // Run benchmark.
  tensorflow::testing::StartTiming();
  for (int i = 0; i < num_iters; ++i) {
    auto result = client->ExecuteLocally(computation, {buffer.get()}, options);
    ASSERT_TRUE(result.ok());
  }
}
BENCHMARK(BM_DynamicSlice);

}  // namespace
}  // namespace xla

// Custom main: parses the legacy CPU-compiler flags before handing the
// remaining arguments to googletest; rejects any unrecognized argument.
int main(int argc, char** argv) {
  std::vector<tensorflow::Flag> flag_list;
  xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << "\n" << usage;
    return 2;
  }
  testing::InitGoogleTest(&argc, argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return 2;
  }
  return RUN_ALL_TESTS();
}