diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-03 14:42:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-03 14:44:57 -0700 |
commit | 86b3f351dd98db9cdfc2fc68a2a4328e90b36035 (patch) | |
tree | cdf8776c0e033925895f446d8e1ce81de883c687 | |
parent | b9b90965de4e475ccff8a571de016026447ee1df (diff) |
[XLA] Redesign: implement and test dynamic slice.
PiperOrigin-RevId: 191502312
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 18 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 41 |
3 files changed, 37 insertions, 25 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index c2e661cb3d..fe8ae77683 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -489,7 +489,23 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, tensorflow::gtl::ArraySlice<int64> slice_sizes) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shape, slice_sizes)); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, + {operand, start_indices}); + }); } XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 5dcd02a1a4..6f58c20f34 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -977,9 +977,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 4f354e6aef..c0a16ad288 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" @@ -112,10 +111,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { void TestR3Wrap() { // Slice at dimension boundaries, but with sizes that cause indices to wrap. RunR3<IndexT, DataT>( - {{{1, 2}, {3, 4}, {5, 6}}, - {{7, 8}, {9, 10}, {11, 12}}}, - {0, 2, 1}, {2, 1, 2}, - {{{6, 5}}, {{12, 11}}}); + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, + {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); } template <typename IndexT, typename DataT> @@ -137,9 +134,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -163,9 +160,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -189,9 +186,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -359,9 +356,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -390,9 +387,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -421,9 +418,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType<DataT>()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -474,13 +471,13 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } // Build dynamic slice computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer input parameter. - ComputationDataHandle input; + XlaOp input; std::unique_ptr<GlobalData> input_data = CreateR3Parameter<T>(input_values, 0, "input_values", &builder, &input); // Initialize and transfer update parameter. - ComputationDataHandle update; + XlaOp update; std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>( update_values, 1, "update_values", &builder, &update); auto starts = builder.ConstantR1<int32>({index, 0, 0}); @@ -672,7 +669,7 @@ void BM_DynamicSlice(int num_iters) { TransferManager::GetForPlatform(platform).ValueOrDie(); int device_ordinal = client->default_device_ordinal(); - ComputationBuilder builder(client, "DynamicSlice"); + XlaBuilder builder("DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] auto input_literal = Literal::CreateR4( |