about summary refs log tree commit diff homepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-15 17:32:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 17:39:26 -0800
commit9648f8040a559f6cf9bbe0501ba96f2b2c2864b1 (patch)
tree57dc6e959e0a534622eaf392ee43b7691378b10e /tensorflow/compiler
parent5b5445b9a7aa2664a90c4fc946ecf268c971425b (diff)
Automated g4 rollback of changelist 179258973
PiperOrigin-RevId: 179260538
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc87
-rw-r--r--tensorflow/compiler/xla/service/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc2
-rw-r--r--tensorflow/compiler/xla/shape_tree.h2
5 files changed, 26 insertions, 69 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 03c22354a9..351fda2517 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -311,32 +311,6 @@ class TensorArrayGatherOp : public XlaOpKernel {
xla::ComputationDataHandle ta = resource->value;
- // Look for the case where the gather takes a simple slice from the
- // tensor array (0, 1, 2, 3, 4, ..., N)
- std::vector<int64> const_indices;
- Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
- if (status.ok()) {
- bool gather_is_dense_slice = true;
- for (auto i = 0; i < const_indices.size(); i++) {
- if (const_indices[i] != i) {
- gather_is_dense_slice = false;
- break;
- }
- }
-
- if (gather_is_dense_slice) {
- std::vector<int64> begin(ta_shape.dims(), 0);
- std::vector<int64> strides(ta_shape.dims(), 1);
- std::vector<int64> end(ta_shape.dims(), 1);
- end[0] = const_indices.size();
- for (auto i = 1; i < ta_shape.dims(); i++) {
- end[i] = ta_shape.dim_size(i);
- }
- ctx->SetOutput(0, b->Slice(ta, begin, end, strides));
- return;
- }
- }
-
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b);
ctx->SetOutput(0, gather);
@@ -378,47 +352,28 @@ class TensorArrayScatterOp : public XlaOpKernel {
const xla::ComputationDataHandle value = ctx->Input(2);
const xla::ComputationDataHandle flow = ctx->Input(3);
- // Look for the case where the scatter is for each sub-tensor in order. The
- // tensor array implementation allows for this to be a straight addition.
- bool scatter_all_elements_in_order = false;
- std::vector<int64> const_indices;
- Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
- if (status.ok() && num_indices == value_shape.dim_size(0)) {
- scatter_all_elements_in_order = true;
- for (auto i = 0; i < num_indices; i++) {
- if (const_indices[i] != i) {
- scatter_all_elements_in_order = false;
- break;
- }
- }
- }
+ auto slice_dims = value_shape.dim_sizes();
+ slice_dims[0] = 1LL;
- if (scatter_all_elements_in_order) {
- ta = b->Add(ta, value);
- } else {
- auto slice_dims = value_shape.dim_sizes();
- slice_dims[0] = 1LL;
-
- std::vector<int64> value_starts(value_shape.dims(), 0);
- auto value_ends = value_shape.dim_sizes();
-
- std::vector<int64> value_strides(value_shape.dims(), 1);
-
- // For every (index, value) pair, update the corresponding TensorArray
- // storage.
- for (int i = 0; i < num_indices; ++i) {
- // Slice out part of the value.
- value_starts[0] = i;
- value_ends[0] = i + 1;
- auto slice = b->Slice(value, value_starts, value_ends, value_strides);
-
- // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto index = b->Slice(indices, {i}, {i + 1}, {1});
- auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
- ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
- }
+ std::vector<int64> value_starts(value_shape.dims(), 0);
+ auto value_ends = value_shape.dim_sizes();
+
+ std::vector<int64> value_strides(value_shape.dims(), 1);
+
+ // For every (index, value) pair, update the corresponding TensorArray
+ // storage.
+ for (int i = 0; i < num_indices; ++i) {
+ // Slice out part of the value.
+ value_starts[0] = i;
+ value_ends[0] = i + 1;
+ auto slice = b->Slice(value, value_starts, value_ends, value_strides);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto index = b->Slice(indices, {i}, {i + 1}, {1});
+ auto start_indices =
+ b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
resource->value = ta;
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 3655a08cf3..07ef98076e 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -90,6 +90,8 @@ cc_library(
":shape_inference",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 8388574716..3278fd5f06 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -339,7 +339,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
ShapeUtil::MakeShape(F32, {42})}),
"param0"));
- // The return value of the computation is the zero-th element of the nested
+ // The return value of the computation is the zero-th element of the nested
// tuple. This element is itself a tuple.
auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 328afe42ba..af726271ae 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1303,7 +1303,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
}
- // Copy the root instruction's result if the it does not match the result
+ // Copy the root instruction's result if the it does not match the result
// layout constraint
if (constraints.ResultLayout() != nullptr &&
!constraints.ResultLayout()->MatchesLayoutInShape(
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index d752619bd6..bf8d190150 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -238,7 +238,7 @@ class ShapeTree {
// (or compatible).
// index : the index of the element in the shape. See ShapeUtil::GetSubshape
// for definition of index.
- // data : The data value at this element.
+ // data : The data value at this element.
template <typename Fn>
void ForEachElement(const Fn& func) const;