aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-26 13:48:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-26 13:51:52 -0700
commitb13e96e21c1229a905a623111dd89d2bd0cba53b (patch)
treed60c94e4a87554943931c438f9885c9a48556ce9
parentb91b5baab23cfb8d9da269166070511c28eeaf84 (diff)
Automated g4 rollback of changelist 160183498
PiperOrigin-RevId: 160189134
-rw-r--r--tensorflow/BUILD1
-rw-r--r--tensorflow/compiler/jit/BUILD5
-rw-r--r--tensorflow/compiler/jit/kernels/BUILD1
-rw-r--r--tensorflow/compiler/plugin/executor/BUILD4
-rw-r--r--tensorflow/compiler/tests/BUILD18
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py4
-rw-r--r--tensorflow/core/kernels/sparse_reduce_sum_op.cc305
-rw-r--r--tensorflow/python/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/sparse_ops_test.py3
9 files changed, 316 insertions, 27 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 6e76db02de..62ae5ae78c 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -208,7 +208,6 @@ filegroup(
"//tensorflow/compiler/jit/kernels:all_files",
"//tensorflow/compiler/jit/legacy_flags:all_files",
"//tensorflow/compiler/jit/ops:all_files",
- "//tensorflow/compiler/plugin/executor:all_files",
"//tensorflow/compiler/tests:all_files",
"//tensorflow/compiler/tf2xla:all_files",
"//tensorflow/compiler/tf2xla/cc:all_files",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 7ebd842218..306e704415 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -15,10 +15,7 @@ package_group(
)
package(
- default_visibility = [
- ":internal",
- "//tensorflow/compiler/plugin/executor:__pkg__",
- ],
+ default_visibility = [":internal"],
)
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 97f3512a6c..ed204b8182 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -2,7 +2,6 @@ licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
- "//tensorflow/compiler/plugin/executor:__pkg__",
"//tensorflow/compiler/tf2xla:internal",
],
)
diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD
index 2e5875705f..9bc706abdf 100644
--- a/tensorflow/compiler/plugin/executor/BUILD
+++ b/tensorflow/compiler/plugin/executor/BUILD
@@ -11,11 +11,9 @@ cc_library(
"*.h",
]),
deps = [
- "//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_jit_headers_lib",
- "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:xla_headers_lib",
- "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:hlo_evaluator",
"//third_party/eigen3",
"@local_config_cuda//cuda:cuda_headers",
"@protobuf//:protobuf_headers",
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 432b24756d..044857d422 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -175,11 +175,6 @@ tf_xla_py_test(
name = "slice_ops_test",
size = "small",
srcs = ["slice_ops_test.py"],
- # TODO(b/62962492): Test fails with assertion error.
- tags = [
- "manual",
- "notap",
- ],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -461,11 +456,6 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
],
- # TODO(b/62961789): Test fails with SIGABRT
- tags = [
- "manual",
- "notap",
- ],
)
cc_library(
@@ -534,12 +524,8 @@ cuda_py_test(
# --dump_graph_dir, and the config file was written by hand.
#
# Run the following to build a minimal benchmark of the computation on Android:
-# $ bazel build -c opt --cxxopt='-std=c++11' --linkopt='-lm' \
-# --cpu=armeabi-v7a \
-# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
-# --crosstool_top=//external:android/crosstool \
-# //tensorflow/compiler/tests:lstm_layer_inference_benchmark
-
+# $ bazel build -c opt --config=android_arm \
+# third_party/tensorflow/compiler/tests:lstm_layer_inference_benchmark
#
# Currently the resulting binary size is ~190KB
tf_library(
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index f53f38f3cf..ce8518267f 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -260,7 +260,11 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
For example:
```python
+<<<<<<< HEAD
mapping_string = tf.constant(["emerson", "lake", "palmer"])
+=======
+ mapping_string = tf.constant(["emerson", "lake", "palmer")
+>>>>>>> 338a7ead4475d6b97b420d6d1c56ff66815e3e7b
indices = tf.constant([1, 5], tf.int64)
values = tf.contrib.lookup.index_to_string(
indices, mapping=mapping_string, default_value="UNKNOWN")
diff --git a/tensorflow/core/kernels/sparse_reduce_sum_op.cc b/tensorflow/core/kernels/sparse_reduce_sum_op.cc
new file mode 100644
index 0000000000..074aab9f9e
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_reduce_sum_op.cc
@@ -0,0 +1,305 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/sparse_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+// TODO(b/31496047): Fix non-standard include order.
+#include <numeric> // clang-format off
+
+using tensorflow::sparse::SparseTensor;
+using tensorflow::gtl::ArraySlice;
+
+namespace tensorflow {
+
+struct ReduceDetails {
+ // The dimensions to call Reorder() with.
+ std::vector<int64> reorder_dims;
+
+ // The dimensions to call group() with after Reorder().
+ std::vector<int64> group_by_dims;
+
+ // The shape after reduction.
+ TensorShape reduced_shape;
+};
+
+// Compute common reduce parameters that'll be used for SparseTensor
+// reductions. Usage:
+// ReduceDetails reduction = SparseTensorReduceHelper(sp, axes, keep_dims);
+// sp.Reorder(reduction.reorder_dims);
+// for (const auto& g : sp.group(reduction.group_by_dims)) {
+// ...
+// }
+// // Set output shape to reduction.reduced_shape.
+ReduceDetails SparseTensorReduceHelper(const SparseTensor &sp,
+ gtl::ArraySlice<int32> axes_slice,
+ bool keep_dims) {
+ ReduceDetails reduction;
+
+ std::vector<int32> reduction_axes(axes_slice.begin(), axes_slice.end());
+ int ndims = sp.dims();
+ for (int64 i = 0; i < reduction_axes.size(); ++i) {
+ reduction_axes[i] = (reduction_axes[i] + ndims) % ndims;
+ }
+ std::sort(reduction_axes.begin(), reduction_axes.end());
+
+ // (0) Calculate the grouping dimensions:
+ // group_by_dims == {0, .., NDIMS-1} \ reduction_axes.
+ std::vector<int64> perm(ndims);
+ std::iota(perm.begin(), perm.end(), 0);
+
+ // Requires perm and reduction_axes_ be sorted; group_by_dims will be
+ // sorted as well.
+ std::set_difference(
+ perm.begin(), perm.end(), reduction_axes.begin(), reduction_axes.end(),
+ std::inserter(reduction.group_by_dims, reduction.group_by_dims.begin()));
+
+ // Now append the rest of the axes (the complement of group_by_dims_);
+ // result is used by Reorder().
+ reduction.reorder_dims = reduction.group_by_dims;
+ std::set_difference(perm.begin(), perm.end(), reduction.group_by_dims.begin(),
+ reduction.group_by_dims.end(),
+ std::back_inserter(reduction.reorder_dims));
+
+ // (1) Calculate the shape after reduction.
+ auto sp_shape = sp.shape();
+ std::vector<int64> out_dim_sizes;
+ if (keep_dims) {
+ out_dim_sizes.reserve(ndims);
+ auto beg = reduction.group_by_dims.begin();
+ auto end = reduction.group_by_dims.end();
+ for (int d = 0; d < ndims; ++d) {
+ if (std::find(beg, end, d) == end) {
+ out_dim_sizes.push_back(1); // A reduced axis.
+ } else {
+ out_dim_sizes.push_back(sp_shape[d]);
+ }
+ }
+ } else {
+ out_dim_sizes = sp.PickDims(reduction.group_by_dims);
+ }
+
+ reduction.reduced_shape = TensorShape(out_dim_sizes);
+ return reduction;
+}
+
+Status ValidateInputs(const Tensor *shape_t, const Tensor *reduction_axes_t) {
+ // indices and values are validated in SparseTensor ctor.
+ if (!TensorShapeUtils::IsVector(shape_t->shape())) {
+ return errors::InvalidArgument(
+ "Expected input_shape to be a vector; got shape: ",
+ shape_t->shape().DebugString());
+ }
+ if (!TensorShapeUtils::IsScalar(reduction_axes_t->shape()) &&
+ !TensorShapeUtils::IsVector(reduction_axes_t->shape())) {
+ return errors::InvalidArgument(
+ "Expected reduction_axes to be a scalar or a vector; got shape: ",
+ reduction_axes_t->shape().DebugString());
+ }
+
+ const auto reduction_axes_flat = reduction_axes_t->flat<int32>();
+ for (int64 i = 0; i < reduction_axes_flat.size(); i++) {
+ int32 axis = reduction_axes_flat(i);
+ if (axis < -shape_t->NumElements() || axis >= shape_t->NumElements()) {
+ return errors::InvalidArgument("Invalid reduction dimension ", axis,
+ ", for input with ",
+ shape_t->NumElements(), " dimensions.");
+ }
+ }
+
+ return Status::OK();
+}
+
+template <typename T>
+class SparseReduceSumOp : public OpKernel {
+ public:
+ explicit SparseReduceSumOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+ }
+
+ void Compute(OpKernelContext *ctx) override {
+ const Tensor *indices_t, *values_t, *shape_t, *reduction_axes_t;
+ OP_REQUIRES_OK(ctx, ctx->input("input_indices", &indices_t));
+ OP_REQUIRES_OK(ctx, ctx->input("input_values", &values_t));
+ OP_REQUIRES_OK(ctx, ctx->input("input_shape", &shape_t));
+ OP_REQUIRES_OK(ctx, ctx->input("reduction_axes", &reduction_axes_t));
+
+ OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
+
+ // TODO(zongheng): we will call Reorder() below, which will modify
+ // in-place the underlying indices and values buffers. To avoid
+ // surprises of this kernel being stateful, we work around the above by
+ // making deep copies here. Remove this if/when we change Reorder()'s
+ // semantics.
+ const auto shape_vec = shape_t->vec<int64>();
+ SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+ TensorShape(shape_vec));
+ ReduceDetails reduction = SparseTensorReduceHelper(
+ sp, reduction_axes_t->flat<int32>(), keep_dims_);
+
+ Tensor *out_values;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, reduction.reduced_shape, &out_values));
+ auto out_flat = out_values->flat<T>();
+ out_flat.setZero();
+
+ Tensor tmp_group_sum;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({}), &tmp_group_sum));
+ auto group_sum = tmp_group_sum.scalar<T>();
+
+ // Compute strides, and use it to convert coords to flat index. The
+ // coordinates returned by .group() have the same ndims as group_by_dims.
+ gtl::InlinedVector<int64, 8> output_strides(reduction.group_by_dims.size());
+ if (!output_strides.empty()) { // Do this iff we don't reduce all.
+ output_strides.back() = 1;
+ for (int d = output_strides.size() - 2; d >= 0; --d) {
+ output_strides[d] =
+ output_strides[d + 1] * shape_vec(reduction.group_by_dims[d + 1]);
+ }
+ }
+
+ auto CoordinatesToFlatIndex = [](ArraySlice<int64> coords,
+ ArraySlice<int64> strides) {
+ if (strides.empty()) { // Reduce all.
+ return 0LL;
+ }
+ CHECK_EQ(coords.size(), strides.size());
+ int64 idx = 0;
+ for (int i = 0; i < coords.size(); ++i) {
+ idx += coords[i] * strides[i];
+ }
+ return idx;
+ };
+
+ // Each group maps one-on-one onto a value in the reduced tensor.
+ // g.group() provides the coordinates of a particular reduced value.
+ sp.Reorder<T>(reduction.reorder_dims);
+ for (const auto &g : sp.group(reduction.group_by_dims)) {
+ group_sum.device(ctx->eigen_cpu_device()) = g.template values<T>().sum();
+ const int64 idx = CoordinatesToFlatIndex(g.group(), output_strides);
+ out_flat(idx) = group_sum();
+ VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
+ << "; idx: " << idx << "; group sum: " << group_sum();
+ }
+ }
+
+ private:
+ // True if the number of dimensions should be maintained.
+ bool keep_dims_;
+};
+
+#define REGISTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseReduceSum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ SparseReduceSumOp<T>)
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+template <typename T>
+class SparseReduceSumSparseOp : public OpKernel {
+ public:
+ explicit SparseReduceSumSparseOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+ }
+
+ void Compute(OpKernelContext *ctx) override {
+ const Tensor *indices_t, *values_t, *shape_t, *reduction_axes_t;
+ OP_REQUIRES_OK(ctx, ctx->input("input_indices", &indices_t));
+ OP_REQUIRES_OK(ctx, ctx->input("input_values", &values_t));
+ OP_REQUIRES_OK(ctx, ctx->input("input_shape", &shape_t));
+ OP_REQUIRES_OK(ctx, ctx->input("reduction_axes", &reduction_axes_t));
+
+ OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
+
+ SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+ TensorShape(shape_t->vec<int64>()));
+ ReduceDetails reduction = SparseTensorReduceHelper(
+ sp, reduction_axes_t->flat<int32>(), keep_dims_);
+
+ sp.Reorder<T>(reduction.reorder_dims);
+ // Count nnzs in the output SparseTensor.
+ int64 nnz = 0;
+ auto iter = sp.group(reduction.group_by_dims);
+ for (auto it = iter.begin(); it != iter.end(); ++it) {
+ nnz++;
+ }
+
+ Tensor *out_indices_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(
+ 0, TensorShape({nnz, reduction.reduced_shape.dims()}),
+ &out_indices_t));
+ typename TTypes<int64>::Matrix out_indices_mat =
+ out_indices_t->matrix<int64>();
+ // For keep_dims. We don't explicitly set dim fields for reduced dims below.
+ out_indices_mat.setZero();
+
+ Tensor *out_values_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(1, TensorShape({nnz}), &out_values_t));
+ auto out_flat = out_values_t->flat<T>();
+
+ Tensor tmp_group_sum;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({}), &tmp_group_sum));
+ auto group_sum = tmp_group_sum.scalar<T>();
+ int64 i = 0;
+ for (const auto &g : sp.group(reduction.group_by_dims)) {
+ group_sum.device(ctx->eigen_cpu_device()) = g.template values<T>().sum();
+ std::vector<int64> group = g.group();
+ for (int64 j = 0; j < group.size(); j++) {
+ if (keep_dims_) {
+ out_indices_mat(i, reduction.group_by_dims[j]) = group[j];
+ } else {
+ out_indices_mat(i, j) = group[j];
+ }
+ }
+ out_flat(i) = group_sum();
+ i++;
+ VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
+ << "; group sum: " << group_sum();
+ }
+
+ Tensor *out_shape_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(
+ 2, TensorShape({reduction.reduced_shape.dims()}),
+ &out_shape_t));
+ auto out_shape_flat = out_shape_t->flat<int64>();
+ auto out_dim_sizes = reduction.reduced_shape.dim_sizes();
+ std::copy(out_dim_sizes.begin(), out_dim_sizes.end(), &out_shape_flat(0));
+ }
+
+ private:
+ // True if the number of dimensions should be maintained.
+ bool keep_dims_;
+};
+
+#define REGISTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseReduceSumSparse").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ SparseReduceSumSparseOp<T>)
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 936348e01d..22b18b9cde 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3056,7 +3056,6 @@ py_test(
srcs = ["client/session_clusterspec_prop_test.py"],
srcs_version = "PY2AND3",
tags = [
- "no_gpu",
"no_pip_gpu",
],
deps = [
@@ -3081,7 +3080,6 @@ py_test(
srcs = ["client/session_list_devices_test.py"],
srcs_version = "PY2AND3",
tags = [
- "no_gpu",
"no_pip_gpu",
],
deps = [
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index 6664362226..14eb2cba68 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+import unittest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -605,6 +606,7 @@ class SparseReduceTest(test_util.TensorFlowTestCase):
self._compare(sp_t, reduction_axes, ndims, True, False)
self._compare(sp_t, reduction_axes, ndims, True, True)
+ @unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13 bug")
def testSimpleAndRandomInputs(self):
if np.__version__ == "1.13.0":
self.skipTest("numpy 1.13.0 bug")
@@ -644,6 +646,7 @@ class SparseReduceTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError("Invalid reduction dimension 2"):
sparse_ops.sparse_reduce_max(sp_t, 2).eval()
+ @unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13 bug")
def testGradient(self):
if np.__version__ == "1.13.0":
self.skipTest("numpy 1.13.0 bug")