aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-03 14:42:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-03 14:44:57 -0700
commit86b3f351dd98db9cdfc2fc68a2a4328e90b36035 (patch)
treecdf8776c0e033925895f446d8e1ce81de883c687
parentb9b90965de4e475ccff8a571de016026447ee1df (diff)
[XLA] Redesign: implement and test dynamic slice.
PiperOrigin-RevId: 191502312
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc18
-rw-r--r--tensorflow/compiler/xla/tests/BUILD3
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc41
3 files changed, 37 insertions, 25 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index c2e661cb3d..fe8ae77683 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -489,7 +489,23 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+ GetShape(start_indices));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferDynamicSliceShape(
+ operand_shape, start_indices_shape, slice_sizes));
+
+ for (int64 size : slice_sizes) {
+ instr.add_dynamic_slice_sizes(size);
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice,
+ {operand, start_indices});
+ });
}
XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 5dcd02a1a4..6f58c20f34 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -977,9 +977,8 @@ xla_test(
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 4f354e6aef..c0a16ad288 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -18,9 +18,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
@@ -112,10 +111,8 @@ class DynamicSliceTest : public ClientLibraryTestBase {
void TestR3Wrap() {
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
RunR3<IndexT, DataT>(
- {{{1, 2}, {3, 4}, {5, 6}},
- {{7, 8}, {9, 10}, {11, 12}}},
- {0, 2, 1}, {2, 1, 2},
- {{{6, 5}}, {{12, 11}}});
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1},
+ {2, 1, 2}, {{{6, 5}}, {{12, 11}}});
}
template <typename IndexT, typename DataT>
@@ -137,9 +134,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -163,9 +160,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -189,9 +186,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -359,9 +356,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -390,9 +387,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -421,9 +418,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -474,13 +471,13 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
}
// Build dynamic slice computation.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer input parameter.
- ComputationDataHandle input;
+ XlaOp input;
std::unique_ptr<GlobalData> input_data =
CreateR3Parameter<T>(input_values, 0, "input_values", &builder, &input);
// Initialize and transfer update parameter.
- ComputationDataHandle update;
+ XlaOp update;
std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>(
update_values, 1, "update_values", &builder, &update);
auto starts = builder.ConstantR1<int32>({index, 0, 0});
@@ -672,7 +669,7 @@ void BM_DynamicSlice(int num_iters) {
TransferManager::GetForPlatform(platform).ValueOrDie();
int device_ordinal = client->default_device_ordinal();
- ComputationBuilder builder(client, "DynamicSlice");
+ XlaBuilder builder("DynamicSlice");
// Create input as a constant: shape [1, 2, 3, 4]
auto input_literal = Literal::CreateR4(