Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--  tensorflow/core/util/bcast.cc  120
-rw-r--r--  tensorflow/core/util/bcast.h  99
-rw-r--r--  tensorflow/core/util/bcast_test.cc  226
-rw-r--r--  tensorflow/core/util/device_name_utils.cc  338
-rw-r--r--  tensorflow/core/util/device_name_utils.h  141
-rw-r--r--  tensorflow/core/util/device_name_utils_test.cc  369
-rw-r--r--  tensorflow/core/util/event.proto  29
-rw-r--r--  tensorflow/core/util/events_writer.cc  144
-rw-r--r--  tensorflow/core/util/events_writer.h  77
-rw-r--r--  tensorflow/core/util/events_writer_test.cc  198
-rw-r--r--  tensorflow/core/util/guarded_philox_random.cc  39
-rw-r--r--  tensorflow/core/util/guarded_philox_random.h  56
-rw-r--r--  tensorflow/core/util/padding.cc  24
-rw-r--r--  tensorflow/core/util/padding.h  37
-rw-r--r--  tensorflow/core/util/port.cc  13
-rw-r--r--  tensorflow/core/util/port.h  11
-rw-r--r--  tensorflow/core/util/saved_tensor_slice.proto  76
-rw-r--r--  tensorflow/core/util/saved_tensor_slice_util.cc  76
-rw-r--r--  tensorflow/core/util/saved_tensor_slice_util.h  110
-rw-r--r--  tensorflow/core/util/saved_tensor_slice_util_test.cc  32
-rw-r--r--  tensorflow/core/util/sparse/README.md  222
-rw-r--r--  tensorflow/core/util/sparse/dim_comparator.h  60
-rw-r--r--  tensorflow/core/util/sparse/group_iterator.cc  49
-rw-r--r--  tensorflow/core/util/sparse/group_iterator.h  120
-rw-r--r--  tensorflow/core/util/sparse/sparse_tensor.h  353
-rw-r--r--  tensorflow/core/util/sparse/sparse_tensor_test.cc  467
-rw-r--r--  tensorflow/core/util/tensor_slice_reader.cc  230
-rw-r--r--  tensorflow/core/util/tensor_slice_reader.h  157
-rw-r--r--  tensorflow/core/util/tensor_slice_reader_cache.cc  94
-rw-r--r--  tensorflow/core/util/tensor_slice_reader_cache.h  73
-rw-r--r--  tensorflow/core/util/tensor_slice_reader_test.cc  395
-rw-r--r--  tensorflow/core/util/tensor_slice_set.cc  148
-rw-r--r--  tensorflow/core/util/tensor_slice_set.h  73
-rw-r--r--  tensorflow/core/util/tensor_slice_set_test.cc  227
-rw-r--r--  tensorflow/core/util/tensor_slice_util.h  88
-rw-r--r--  tensorflow/core/util/tensor_slice_util_test.cc  91
-rw-r--r--  tensorflow/core/util/tensor_slice_writer.cc  110
-rw-r--r--  tensorflow/core/util/tensor_slice_writer.h  149
-rw-r--r--  tensorflow/core/util/tensor_slice_writer_test.cc  248
-rw-r--r--  tensorflow/core/util/use_cudnn.cc  20
-rw-r--r--  tensorflow/core/util/use_cudnn.h  12
-rw-r--r--  tensorflow/core/util/util.cc  81
-rw-r--r--  tensorflow/core/util/util.h  40
-rw-r--r--  tensorflow/core/util/work_sharder.cc  57
-rw-r--r--  tensorflow/core/util/work_sharder.h  33
-rw-r--r--  tensorflow/core/util/work_sharder_test.cc  57
46 files changed, 5869 insertions, 0 deletions
diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc
new file mode 100644
index 0000000000..4e70b78751
--- /dev/null
+++ b/tensorflow/core/util/bcast.cc
@@ -0,0 +1,120 @@
+#include "tensorflow/core/util/bcast.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+/* static */
+void BCast::Reverse(Vec* shape) { std::reverse(shape->begin(), shape->end()); }
+
+BCast::BCast(const Vec& sx, const Vec& sy) {
+ // Reverse the shape of x and y for convenience.
+ // After the reverse, 0-th is the inner-most dimension.
+ Vec x = sx;
+ Reverse(&x);
+ Vec y = sy;
+ Reverse(&y);
+
+ // 1-extend and align x and y so that they are the same size.
+ if (x.size() > y.size()) {
+ y.resize(x.size(), 1);
+ } else {
+ x.resize(y.size(), 1);
+ }
+
+ // Go through each dimension, starting from the inner-most one, and
+ // compare the dimensions of x and y. They are compatible if they
+ // are equal or either is 1.
+ enum State {
+ UNKNOWN,
+ SAME,
+ X_ONE,
+ Y_ONE,
+ };
+ State prev = UNKNOWN;
+ const int64 n = x.size();
+ for (int i = 0; i < n; ++i) {
+ // Output shape.
+ State curr = UNKNOWN;
+ const int64 x_i = x[i]; // i-th dimension of x.
+ CHECK_GE(x_i, 0);
+ const int64 y_i = y[i]; // i-th dimension of y.
+ CHECK_GE(y_i, 0);
+ int64 o_i; // i-th dimension of the output.
+ int64 bx_i; // i-th broadcast for x.
+ int64 by_i; // i-th broadcast for y.
+ // Invariant:
+ // o_i = x_i * bx_i = y_i * by_i
+ if (x_i == y_i) {
+ // No broadcast.
+ o_i = x_i;
+ bx_i = 1;
+ by_i = 1;
+ curr = SAME;
+ } else if (x_i == 1) {
+ // x broadcast to y on this dimension.
+ o_i = y_i;
+ bx_i = y_i;
+ by_i = 1;
+ grad_x_reduce_idx_.push_back(n - 1 - i);
+ curr = X_ONE;
+ } else if (y_i == 1) {
+ // y broadcast to x on this dimension.
+ o_i = x_i;
+ bx_i = 1;
+ by_i = x_i;
+ grad_y_reduce_idx_.push_back(n - 1 - i);
+ curr = Y_ONE;
+ } else {
+ valid_ = false;
+ return;
+ }
+ output_.push_back(o_i);
+ // Reshape/broadcast.
+ // Invariant:
+ // result_[i] == x_reshape_[i] * x_bcast_[i] == y_reshape_[i] * y_bcast_[i]
+ if (curr == SAME && x_i == 1) {
+ // Both side are 1s.
+ grad_x_reduce_idx_.push_back(n - 1 - i);
+ grad_y_reduce_idx_.push_back(n - 1 - i);
+ continue;
+ } else if (prev == curr) {
+ // It is a run of the same cases (no broadcast, x broadcast to
+ // y, y broadcast to x). We can reshape the input so that fewer
+ // dimensions are involved in the intermediate computation.
+ result_.back() *= o_i;
+ x_reshape_.back() *= x_i;
+ x_bcast_.back() *= bx_i;
+ y_reshape_.back() *= y_i;
+ y_bcast_.back() *= by_i;
+ } else {
+ result_.push_back(o_i);
+ x_reshape_.push_back(x_i);
+ x_bcast_.push_back(bx_i);
+ y_reshape_.push_back(y_i);
+ y_bcast_.push_back(by_i);
+ }
+ prev = curr;
+ }
+
+ if (result_.empty()) {
+ // Can happen when both x and y are effectively scalar.
+ result_.push_back(1);
+ x_reshape_.push_back(1);
+ x_bcast_.push_back(1);
+ y_reshape_.push_back(1);
+ y_bcast_.push_back(1);
+ }
+
+ // Reverse all vectors since x and y were reversed at the very
+ // beginning.
+ Reverse(&x_reshape_);
+ Reverse(&x_bcast_);
+ Reverse(&y_reshape_);
+ Reverse(&y_bcast_);
+ Reverse(&result_);
+ Reverse(&output_);
+ Reverse(&grad_x_reduce_idx_);
+ Reverse(&grad_y_reduce_idx_);
+}
+
+} // end namespace tensorflow
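As a minimal sketch (not part of this change), the following program prints what BCast computes for one of the shape pairs exercised in bcast_test.cc; the expected values in the comments come from the Basic_Tensor_Matrix_Column test, and illustrate how runs of dimensions with the same broadcast behavior are coalesced:

#include <iostream>
#include "tensorflow/core/util/bcast.h"

int main() {
  // x shape [11,7,5,3,2], y shape [3,1]: the three leading dimensions of x
  // form one "no broadcast" run and are coalesced into 11*7*5 = 385.
  tensorflow::BCast b({11, 7, 5, 3, 2}, {3, 1});
  // Expected: x_reshape [385,3,2], x_bcast [1,1,1],
  //           y_reshape [1,3,1],   y_bcast [385,1,2],
  //           result    [385,3,2], output  [11,7,5,3,2].
  for (const auto& vec : {b.x_reshape(), b.x_bcast(), b.y_reshape(),
                          b.y_bcast(), b.result_shape(), b.output_shape()}) {
    for (auto d : vec) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}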
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h
new file mode 100644
index 0000000000..9f0233e415
--- /dev/null
+++ b/tensorflow/core/util/bcast.h
@@ -0,0 +1,99 @@
+#ifndef TENSORFLOW_UTIL_BCAST_H_
+#define TENSORFLOW_UTIL_BCAST_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+// BCast is a helper for broadcasting binary tensor operations.
+// TensorFlow's broadcasting rule follows that of numpy (see
+// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
+//
+// The rule has the following properties:
+//
+// 1. suffix matching: the rule starts with the right-most
+// dimension, and works towards the left-most dimension. Since
+// TensorFlow is row-major, the right-most dimension (the last
+// element in the shape of a tensor) is the inner-most, a.k.a.
+// the fastest changing, dimension.
+//
+// 2. Two dimensions are compatible for broadcasting if both are the
+// same or either is 1.
+//
+// BCast takes the shapes of two tensors and computes a few vectors of
+// int64 that are useful for the caller to reshape the tensors, apply
+// the right broadcasts to them, compute the broadcasted operation,
+// and possibly the gradients. In a nutshell, the caller is expected
+// to compute the broadcasted operation as follows:
+//
+// BCast b(x.shape(), y.shape());
+// output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
+// _op_
+// y.reshape(b.y_reshape()).broadcast(b.y_bcast())
+//
+// For the gradient computation,
+// grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
+// .reshape(x.shape())
+// grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
+// .reshape(y.shape())
+// backprop_x and backprop_y are functionals of the binary function "op",
+// e.g.,
+// for +, backprop_x(x, y) = backprop_y(x, y) = 1;
+// for *, backprop_x(x, y) = y, backprop_y(x, y) = x;
+// for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
+//
+// The multiplication grad * backprop_x is itself a broadcasting
+// operation following the same rule.
+//
+// TODO(zhifengc): Add support for n-ary (n >= 2).
+class BCast {
+ public:
+ // A vector of int64 representing the shape of a tensor. The 0-th
+ // element is the outer-most dimension and the last element is the
+ // inner-most dimension. Note that we do not use TensorShape since
+ // it's more convenient to manipulate Vec directly for this module.
+ typedef std::vector<int64> Vec;
+
+ BCast(const Vec& x, const Vec& y);
+ ~BCast() {}
+
+ // Returns true iff two operands are compatible according to the
+ // broadcasting rule.
+ bool IsValid() const { return valid_; }
+
+ // If and only if IsValid(), the following fields can be used in
+ // implementing a broadcasted binary tensor operation according to
+ // the broadcasting rule.
+ const Vec& x_reshape() const { return x_reshape_; }
+ const Vec& x_bcast() const { return x_bcast_; }
+ const Vec& y_reshape() const { return y_reshape_; }
+ const Vec& y_bcast() const { return y_bcast_; }
+ const Vec& result_shape() const { return result_; }
+ const Vec& output_shape() const { return output_; }
+ const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
+ const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }
+
+ private:
+ bool valid_ = true;
+ Vec x_reshape_;
+ Vec x_bcast_;
+ Vec y_reshape_;
+ Vec y_bcast_;
+ Vec result_;
+ Vec output_;
+ Vec grad_x_reduce_idx_;
+ Vec grad_y_reduce_idx_;
+
+ static void Reverse(Vec* shape);
+ static bool HasZero(const Vec& shape);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BCast);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_BCAST_H_
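For the gradient bookkeeping described in the header comment, a minimal sketch (assuming only the header above) that prints the reduction indices for the mixed-rank case verified in bcast_test.cc (Complex_BCast_To_Each_Other):

#include <iostream>
#include "tensorflow/core/util/bcast.h"

int main() {
  // x = [11,1,5,1,2] and y = [7,1,3,1] broadcast to each other,
  // with output shape [11,7,5,3,2].
  tensorflow::BCast b({11, 1, 5, 1, 2}, {7, 1, 3, 1});
  // grad_x is summed over output dims {1,3}; grad_y over {0,2,4}.
  for (auto i : b.grad_x_reduce_idx()) std::cout << i << ' ';
  std::cout << '\n';
  for (auto i : b.grad_y_reduce_idx()) std::cout << i << ' ';
  std::cout << '\n';
  return 0;
}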
diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc
new file mode 100644
index 0000000000..02d18586d6
--- /dev/null
+++ b/tensorflow/core/util/bcast_test.cc
@@ -0,0 +1,226 @@
+#include "tensorflow/core/util/bcast.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
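+// Formats the vectors computed by BCast as a single string:
+// "[x_reshape][x_bcast][y_reshape][y_bcast][result][output]
+// [grad_x_reduce_idx][grad_y_reduce_idx]". The tests below compare
+// this against expected literals.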
+string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y) {
+ tensorflow::BCast b(x, y);
+ if (!b.IsValid()) {
+ return "invalid";
+ }
+ string ret;
+ strings::StrAppend(&ret, "[", str_util::Join(b.x_reshape(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.x_bcast(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.y_reshape(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.y_bcast(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.result_shape(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.output_shape(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.grad_x_reduce_idx(), ","),
+ "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.grad_y_reduce_idx(), ","),
+ "]");
+ return ret;
+}
+
+TEST(BCastTest, Invalid) {
+ EXPECT_EQ("invalid", BCast({5, 3, 2}, {3}));
+ EXPECT_EQ("invalid", BCast({5, 3, 2}, {2, 2}));
+ EXPECT_EQ("invalid", BCast({5, 3, 2}, {10, 1, 1}));
+ EXPECT_EQ("invalid", BCast({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1}));
+}
+
+TEST(BCastTest, Basic_SameShape) {
+ // Effectively no broadcast needed.
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}),
+ "[2310][1][2310][1]"
+ "[2310]"
+ "[11,7,5,3,2]"
+ "[][]");
+}
+
+TEST(BCastTest, Basic_Scalar_Scalar) {
+ // Effectively it's a scalar and a scalar.
+ // [1, 1] [1]
+ EXPECT_EQ(BCast({1, 1}, {1}),
+ "[1][1][1][1]"
+ "[1]"
+ "[1,1]"
+ "[0,1][0,1]");
+
+ // [1] [1, 1]
+ EXPECT_EQ(BCast({1}, {1, 1}),
+ "[1][1][1][1]"
+ "[1]"
+ "[1,1]"
+ "[0,1][0,1]");
+}
+
+TEST(BCastTest, Basic_Tensor_Scalar) {
+ // Effectively it's a tensor and a scalar.
+ // [11, 7, 5, 3, 2] [1]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {1}),
+ "[2310][1][1][2310]"
+ "[2310]"
+ "[11,7,5,3,2]"
+ "[][0,1,2,3,4]");
+
+ // [1] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2}),
+ "[1][2310][2310][1]"
+ "[2310]"
+ "[11,7,5,3,2]"
+ "[0,1,2,3,4][]");
+}
+
+TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) {
+ // Effectively it's a tensor and a scalar.
+ // [11, 7, 5, 3, 2, 1] [1]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2, 1}, {1}),
+ "[2310][1][1][2310]"
+ "[2310]"
+ "[11,7,5,3,2,1]"
+ "[5][0,1,2,3,4,5]");
+
+ // [1] [11, 7, 5, 3, 2, 1]
+ EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2, 1}),
+ "[1][2310][2310][1]"
+ "[2310]"
+ "[11,7,5,3,2,1]"
+ "[0,1,2,3,4,5][5]");
+
+ // Effectively it's a tensor and a scalar.
+ // [11, 7, 5, 1, 1, 3, 2, 1] [1]
+ EXPECT_EQ(BCast({11, 7, 5, 1, 1, 3, 2, 1, 1}, {1}),
+ "[2310][1][1][2310]"
+ "[2310]"
+ "[11,7,5,1,1,3,2,1,1]"
+ "[3,4,7,8][0,1,2,3,4,5,6,7,8]");
+
+ // [1] [11, 7, 5, 1, 1, 3, 2, 1]
+ EXPECT_EQ(BCast({1}, {11, 7, 5, 1, 1, 3, 2, 1, 1}),
+ "[1][2310][2310][1]"
+ "[2310]"
+ "[11,7,5,1,1,3,2,1,1]"
+ "[0,1,2,3,4,5,6,7,8][3,4,7,8]");
+}
+
+TEST(BCastTest, Basic_Tensor_Vector) {
+ // [11, 7, 5, 3, 2] [2]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {2}),
+ "[1155,2][1,1][1,2][1155,1]"
+ "[1155,2]"
+ "[11,7,5,3,2]"
+ "[][0,1,2,3]");
+
+ // [2] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({2}, {11, 7, 5, 3, 2}),
+ "[1,2][1155,1][1155,2][1,1]"
+ "[1155,2]"
+ "[11,7,5,3,2]"
+ "[0,1,2,3][]");
+}
+
+TEST(BCastTest, Basic_Tensor_Matrix) {
+ // [11, 7, 5, 3, 2] [3, 2]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 2}),
+ "[385,6][1,1][1,6][385,1]"
+ "[385,6]"
+ "[11,7,5,3,2]"
+ "[][0,1,2]");
+ // [3, 2] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({3, 2}, {11, 7, 5, 3, 2}),
+ "[1,6][385,1][385,6][1,1]"
+ "[385,6]"
+ "[11,7,5,3,2]"
+ "[0,1,2][]");
+}
+
+TEST(BCastTest, Basic_Tensor_Matrix_Column) {
+ // [11, 7, 5, 3, 2] [3, 1]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 1}),
+ "[385,3,2][1,1,1][1,3,1][385,1,2]"
+ "[385,3,2]"
+ "[11,7,5,3,2]"
+ "[][0,1,2,4]");
+
+ // [3, 1] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({3, 1}, {11, 7, 5, 3, 2}),
+ "[1,3,1][385,1,2][385,3,2][1,1,1]"
+ "[385,3,2]"
+ "[11,7,5,3,2]"
+ "[0,1,2,4][]");
+}
+
+TEST(BCastTest, Basic_Tensor_Matrix_As_Tensor) {
+ // [11, 7, 5, 3, 2] [7, 5, 1, 1]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {7, 5, 1, 1}),
+ "[11,35,6][1,1,1][1,35,1][11,1,6]"
+ "[11,35,6]"
+ "[11,7,5,3,2]"
+ "[][0,3,4]");
+
+ // [7, 5, 1, 1] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({7, 5, 1, 1}, {11, 7, 5, 3, 2}),
+ "[1,35,1][11,1,6][11,35,6][1,1,1]"
+ "[11,35,6]"
+ "[11,7,5,3,2]"
+ "[0,3,4][]");
+}
+
+TEST(BCastTest, Complex_BCast_To_Each_Other) {
+ // Rare cases. x and y broadcast to each other. x and y are of
+ // different ranks.
+ // Can be verified in numpy as:
+ // import numpy as np
+ // x = np.arange(0,110).reshape([11,1,5,1,2])
+ // y = np.arange(0,21).reshape([7,1,3,1])
+ // np.shape(x + y)
+ // Out[.]: (11, 7, 5, 3, 2)
+ EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}),
+ "[11,1,5,1,2][1,7,1,3,1][1,7,1,3,1][11,1,5,1,2]"
+ "[11,7,5,3,2]"
+ "[11,7,5,3,2]"
+ "[1,3][0,2,4]");
+}
+
+TEST(BCastTest, TestZeroDimensionShape) {
+ EXPECT_EQ(BCast({2, 0, 5}, {5}),
+ "[0,5][1,1][1,5][0,1]"
+ "[0,5]"
+ "[2,0,5]"
+ "[][0,1]");
+ EXPECT_EQ(BCast({5}, {2, 0, 5}),
+ "[1,5][0,1][0,5][1,1]"
+ "[0,5]"
+ "[2,0,5]"
+ "[0,1][]");
+
+ EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {5}),
+ "[0,5][1,1][1,5][0,1]"
+ "[0,5]"
+ "[2,0,3,0,5]"
+ "[][0,1,2,3]");
+ EXPECT_EQ(BCast({5}, {2, 0, 3, 0, 5}),
+ "[1,5][0,1][0,5][1,1]"
+ "[0,5]"
+ "[2,0,3,0,5]"
+ "[0,1,2,3][]");
+
+ EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {3, 1, 5}),
+ "[0,3,0,5][1,1,1,1][1,3,1,5][0,1,0,1]"
+ "[0,3,0,5]"
+ "[2,0,3,0,5]"
+ "[][0,1,3]");
+ EXPECT_EQ(BCast({3, 1, 5}, {2, 0, 3, 0, 5}),
+ "[1,3,1,5][0,1,0,1][0,3,0,5][1,1,1,1]"
+ "[0,3,0,5]"
+ "[2,0,3,0,5]"
+ "[0,1,3][]");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
new file mode 100644
index 0000000000..b8c6a77dd0
--- /dev/null
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -0,0 +1,338 @@
+#include "tensorflow/core/util/device_name_utils.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+static bool IsAlpha(char c) {
+ return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
+}
+
+static bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
+
+// Returns true iff "in" is a valid job name.
+static bool IsJobName(StringPiece in) {
+ if (in.empty()) return false;
+ if (!IsAlpha(in[0])) return false;
+ for (size_t i = 1; i < in.size(); ++i) {
+ if (!(IsAlphaNum(in[i]) || in[i] == '_')) return false;
+ }
+ return true;
+}
+
+// Returns true and fills in "*job" iff "*in" starts with a job name.
+static bool ConsumeJobName(StringPiece* in, string* job) {
+ if (in->empty()) return false;
+ if (!IsAlpha((*in)[0])) return false;
+ size_t i = 1;
+ for (; i < in->size(); ++i) {
+ const char c = (*in)[i];
+ if (c == '/') break;
+ if (!(IsAlphaNum(c) || c == '_')) {
+ return false;
+ }
+ }
+ job->assign(in->data(), i);
+ in->remove_prefix(i);
+ return true;
+}
+
+// Returns true and fills in "*device_type" iff "*in" starts with a device type
+// name.
+static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
+ if (in->empty()) return false;
+ if (!IsAlpha((*in)[0])) return false;
+ size_t i = 1;
+ for (; i < in->size(); ++i) {
+ const char c = (*in)[i];
+ if (c == '/' || c == ':') break;
+ if (!(IsAlphaNum(c) || c == '_')) {
+ return false;
+ }
+ }
+ device_type->assign(in->data(), i);
+ in->remove_prefix(i);
+ return true;
+}
+
+// Returns true and fills in "*val" iff "*in" starts with a decimal
+// number.
+static bool ConsumeNumber(StringPiece* in, int* val) {
+ uint64 tmp;
+ if (str_util::ConsumeLeadingDigits(in, &tmp)) {
+ *val = tmp;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+/* static */
+string DeviceNameUtils::FullName(const string& job, int replica, int task,
+ const string& type, int id) {
+ CHECK(IsJobName(job)) << job;
+ CHECK_LE(0, replica);
+ CHECK_LE(0, task);
+ CHECK(!type.empty());
+ CHECK_LE(0, id);
+ return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
+ "/device:", type, ":", id);
+}
+
+bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
+ p->Clear();
+ if (fullname == "/") {
+ return true;
+ }
+ StringPiece tmp;
+ while (!fullname.empty()) {
+ if (str_util::ConsumePrefix(&fullname, "/job:")) {
+ p->has_job = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
+ return false;
+ }
+ } else if (str_util::ConsumePrefix(&fullname, "/replica:")) {
+ p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
+ return false;
+ }
+ } else if (str_util::ConsumePrefix(&fullname, "/task:")) {
+ p->has_task = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
+ return false;
+ }
+ } else if (str_util::ConsumePrefix(&fullname, "/device:")) {
+ p->has_type = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
+ return false;
+ }
+ if (!str_util::ConsumePrefix(&fullname, ":")) {
+ p->has_id = false;
+ } else {
+ p->has_id = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
+ return false;
+ }
+ }
+
+ } else if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
+ str_util::ConsumePrefix(&fullname, "/CPU:")) {
+ p->has_type = true;
+ p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...'
+ p->has_id = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
+ return false;
+ }
+ } else if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
+ str_util::ConsumePrefix(&fullname, "/GPU:")) {
+ p->has_type = true;
+ p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...'
+ p->has_id = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
+/* static */
+string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
+ string buf;
+ if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
+ if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
+ if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
+ if (pn.has_type) {
+ strings::StrAppend(&buf, "/", pn.type, ":");
+ if (pn.has_id) {
+ strings::StrAppend(&buf, pn.id);
+ } else {
+ strings::StrAppend(&buf, "*");
+ }
+ }
+ return buf;
+}
+
+/* static */
+bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
+ const ParsedName& more_specific) {
+ if (less_specific.has_job &&
+ (!more_specific.has_job || (less_specific.job != more_specific.job))) {
+ return false;
+ }
+ if (less_specific.has_replica &&
+ (!more_specific.has_replica ||
+ (less_specific.replica != more_specific.replica))) {
+ return false;
+ }
+ if (less_specific.has_task &&
+ (!more_specific.has_task || (less_specific.task != more_specific.task))) {
+ return false;
+ }
+ if (less_specific.has_type &&
+ (!more_specific.has_type || (less_specific.type != more_specific.type))) {
+ return false;
+ }
+ if (less_specific.has_id &&
+ (!more_specific.has_id || (less_specific.id != more_specific.id))) {
+ return false;
+ }
+ return true;
+}
+
+/* static */
+bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
+ const ParsedName& name) {
+ CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
+ name.has_id);
+
+ if (pattern.has_job && (pattern.job != name.job)) return false;
+ if (pattern.has_replica && (pattern.replica != name.replica)) return false;
+ if (pattern.has_task && (pattern.task != name.task)) return false;
+ if (pattern.has_type && (pattern.type != name.type)) return false;
+ if (pattern.has_id && (pattern.id != name.id)) return false;
+ return true;
+}
+
+/* static */
+Status DeviceNameUtils::MergeDevNames(ParsedName* target,
+ const ParsedName& other,
+ bool allow_soft_placement) {
+ if (other.has_job) {
+ if (target->has_job && target->job != other.job) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible jobs: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_job = other.has_job;
+ target->job = other.job;
+ }
+ }
+
+ if (other.has_replica) {
+ if (target->has_replica && target->replica != other.replica) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible replicas: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_replica = other.has_replica;
+ target->replica = other.replica;
+ }
+ }
+
+ if (other.has_task) {
+ if (target->has_task && target->task != other.task) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible tasks: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_task = other.has_task;
+ target->task = other.task;
+ }
+ }
+
+ if (other.has_type) {
+ if (target->has_type && target->type != other.type) {
+ if (!allow_soft_placement) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible types: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_id = false;
+ target->has_type = false;
+ return Status::OK();
+ }
+ } else {
+ target->has_type = other.has_type;
+ target->type = other.type;
+ }
+ }
+
+ if (other.has_id) {
+ if (target->has_id && target->id != other.id) {
+ if (!allow_soft_placement) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible ids: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_id = false;
+ return Status::OK();
+ }
+ } else {
+ target->has_id = other.has_id;
+ target->id = other.id;
+ }
+ }
+
+ return Status::OK();
+}
+
+/* static */
+bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
+ const ParsedName& b) {
+ return (a.has_job && b.has_job && (a.job == b.job)) &&
+ (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
+ (a.has_task && b.has_task && (a.task == b.task));
+}
+
+/* static */
+bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
+ ParsedName x;
+ ParsedName y;
+ return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
+ IsSameAddressSpace(x, y);
+}
+
+/* static */
+string DeviceNameUtils::LocalName(StringPiece type, int id) {
+ return strings::StrCat(type, ":", id);
+}
+
+/* static */
+string DeviceNameUtils::LocalName(StringPiece fullname) {
+ ParsedName x;
+ CHECK(ParseFullName(fullname, &x)) << fullname;
+ return LocalName(x.type, x.id);
+}
+
+/* static */
+bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
+ if (!ConsumeDeviceType(&name, &p->type)) {
+ return false;
+ }
+ if (!str_util::ConsumePrefix(&name, ":")) {
+ return false;
+ }
+ if (!ConsumeNumber(&name, &p->id)) {
+ return false;
+ }
+ return name.empty();
+}
+
+/* static */
+bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
+ string* device) {
+ ParsedName pn;
+ if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
+ *task = strings::StrCat(
+ (pn.has_job ? strings::StrCat("/job:", pn.job) : ""),
+ (pn.has_replica ? strings::StrCat("/replica:", pn.replica) : ""),
+ (pn.has_task ? strings::StrCat("/task:", pn.task) : ""));
+ *device = strings::StrCat(pn.type, ":", pn.id);
+ return true;
+ }
+ return false;
+}
+
+} // namespace tensorflow
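A minimal sketch (not part of this change; the function name MergeExample is hypothetical) of the merge semantics implemented above, mirroring the MergeDevNames test: disjoint components combine, while conflicting components are an error unless allow_soft_placement drops them:

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
Status MergeExample() {
  DeviceNameUtils::ParsedName a, b;
  CHECK(DeviceNameUtils::ParseFullName("/job:foo", &a));
  CHECK(DeviceNameUtils::ParseFullName("/task:7", &b));
  Status s = DeviceNameUtils::MergeDevNames(&a, b);
  // On success, a now describes "/job:foo/task:7".
  return s;
}
}  // namespace tensorflow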
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
new file mode 100644
index 0000000000..8b0a24ed0d
--- /dev/null
+++ b/tensorflow/core/util/device_name_utils.h
@@ -0,0 +1,141 @@
+#ifndef TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
+#define TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// In TensorFlow a device name is a string of the following form:
+// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
+//
+// <name> is a short identifier conforming to the regexp
+// [a-zA-Z][_a-zA-Z0-9]*
+// <type> is a supported device type (e.g. 'cpu' or 'gpu')
+// <replica>, <task>, <device_num> are small non-negative integers and are
+// densely allocated (except in tests).
+//
+// For some purposes, we also allow device patterns, which can specify
+// some or none of the specific fields above, with missing components
+// or "<component>:*" indicating "any value allowed for that component".
+//
+// For example:
+// "/job:param_server" - Consider any devices in the "param_server" job
+// "/device:cpu:*" - Consider any cpu devices in any job/task/replica
+// "/job:*/replica:*/task:*/device:cpu:*" - Consider any cpu devices in any
+// job/task/replica
+// "/job:w/replica:0/task:0/device:gpu:*" - Consider any gpu devices in
+// replica 0, task 0, of job "w"
+class DeviceNameUtils {
+ public:
+ // Returns a fully qualified device name given the parameters.
+ static string FullName(const string& job, int replica, int task,
+ const string& type, int id);
+
+ struct ParsedName {
+ void Clear() {
+ has_job = false;
+ has_replica = false;
+ has_task = false;
+ has_type = false;
+ has_id = false;
+ job.clear();
+ replica = 0;
+ task = 0;
+ type.clear();
+ id = 0;
+ }
+
+ bool operator==(const ParsedName& other) const {
+ return (has_job ? (other.has_job && job == other.job) : !other.has_job) &&
+ (has_replica ? (other.has_replica && replica == other.replica)
+ : !other.has_replica) &&
+ (has_task ? (other.has_task && task == other.task)
+ : !other.has_task) &&
+ (has_type ? (other.has_type && type == other.type)
+ : !other.has_type) &&
+ (has_id ? (other.has_id && id == other.id) : !other.has_id);
+ }
+
+ bool has_job = false;
+ string job;
+ bool has_replica = false;
+ int replica = 0;
+ bool has_task = false;
+ int task = 0;
+ bool has_type = false;
+ string type;
+ bool has_id = false;
+ int id = 0;
+ };
+ // Parses "fullname" into "*parsed". Returns true iff succeeds.
+ static bool ParseFullName(StringPiece fullname, ParsedName* parsed);
+
+ // Returns true if "name" specifies any non-trivial constraint on the device.
+ static bool HasSomeDetails(const ParsedName& name) {
+ return name.has_job || name.has_replica || name.has_task || name.has_type ||
+ name.has_id;
+ }
+
+ // Returns true if more_specific is a specification of
+ // less_specific, i.e. everywhere that less-specific has a
+ // non-wildcard component value, more_specific has the same value
+ // for that component.
+ static bool IsSpecification(const ParsedName& less_specific,
+ const ParsedName& more_specific);
+
+ // Like IsSpecification, but the second argument "name" must have a
+ // non-wildcard value for all of its components.
+ static bool IsCompleteSpecification(const ParsedName& pattern,
+ const ParsedName& name);
+
+ // True iff there exists any possible complete device name that is
+ // a specification of both "a" and "b".
+ static inline bool AreCompatibleDevNames(const ParsedName& a,
+ const ParsedName& b) {
+ return IsSpecification(a, b) || IsSpecification(b, a);
+ }
+
+ // Merges the device specifications in "*target" and "other", and
+ // stores the result in "*target". Returns OK if "*target" and
+ // "other" are compatible, otherwise returns an error.
+ static Status MergeDevNames(ParsedName* target, const ParsedName& other) {
+ return MergeDevNames(target, other, false);
+ }
+ static Status MergeDevNames(ParsedName* target, const ParsedName& other,
+ bool allow_soft_placement);
+
+ // Returns true iff devices identified by 'src' and 'dst' are in the
+ // same address space.
+ static bool IsSameAddressSpace(StringPiece src, StringPiece dst);
+ static bool IsSameAddressSpace(const ParsedName& src, const ParsedName& dst);
+
+ // Returns the local device given its "type" and "id".
+ static string LocalName(StringPiece type, int id);
+
+ // Returns a short local device name (cpu:0, gpu:1, etc) based on
+ // the given fullname.
+ static string LocalName(StringPiece fullname);
+
+ // If "name" is a valid local device name (cpu:0, gpu:1, etc.),
+ // fills in parsed.type and parsed.id accordingly. Returns true iff
+ // succeeds.
+ static bool ParseLocalName(StringPiece name, ParsedName* parsed);
+
+ // Splits a fully-qualified device name into a task identifier and a
+ // relative device identifier. It first parses "name" using
+ // ParseFullName(), then assigns *task with everything except for
+ // the local device component, and assigns the relative device
+ // component into *device. This function will still return true if
+ // the task component is empty, but it requires the relative device
+ // component to be fully specified.
+ static bool SplitDeviceName(StringPiece name, string* task, string* device);
+
+ static string ParsedNameToString(const ParsedName& pn);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
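A minimal usage sketch (assuming only the header above): parse a full name, then round-trip it through ParsedNameToString; note that the printed form uses "/<type>:<id>" rather than "/device:<type>:<id>":

#include <iostream>
#include "tensorflow/core/util/device_name_utils.h"

int main() {
  tensorflow::DeviceNameUtils::ParsedName p;
  if (tensorflow::DeviceNameUtils::ParseFullName(
          "/job:worker/replica:0/task:1/device:GPU:2", &p)) {
    // Prints "/job:worker/replica:0/task:1/GPU:2".
    std::cout << tensorflow::DeviceNameUtils::ParsedNameToString(p) << "\n";
  }
  return 0;
}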
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
new file mode 100644
index 0000000000..14f30d6de5
--- /dev/null
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -0,0 +1,369 @@
+#include "tensorflow/core/util/device_name_utils.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(DeviceNameUtilsTest, Basic) {
+ EXPECT_EQ(DeviceNameUtils::FullName("hello", 1, 2, "CPU", 3),
+ "/job:hello/replica:1/task:2/device:CPU:3");
+
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_FALSE(DeviceNameUtils::ParseFullName("foobar", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:123/replica:1/task:2/gpu:3", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:123/replica:1/task:2/gpu:", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseFullName(
+ "/job:123/replica:1/task:2/device:gpu:", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:foo/replica:-1/task:2/gpu:3", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:-2/gpu:3", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:2/bar:3", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseFullName(
+ "/job:foo/replica:1/task:2/gpu:3/extra", &p));
+ EXPECT_TRUE(
+ DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:2/gpu:3", &p));
+ EXPECT_TRUE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_TRUE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.job, "foo");
+ EXPECT_EQ(p.replica, 1);
+ EXPECT_EQ(p.task, 2);
+ EXPECT_EQ(p.type, "GPU");
+ EXPECT_EQ(p.id, 3);
+ }
+ {
+ // Allow _ in job names.
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(
+ "/job:foo_bar/replica:1/task:2/gpu:3", &p));
+ EXPECT_TRUE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_TRUE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.job, "foo_bar");
+ EXPECT_EQ(p.replica, 1);
+ EXPECT_EQ(p.task, 2);
+ EXPECT_EQ(p.type, "GPU");
+ EXPECT_EQ(p.id, 3);
+ }
+ {
+ // Allow _ in job names.
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(
+ "/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
+ EXPECT_TRUE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_TRUE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.job, "foo_bar");
+ EXPECT_EQ(p.replica, 1);
+ EXPECT_EQ(p.task, 2);
+ EXPECT_EQ(p.type, "GPU");
+ EXPECT_EQ(p.id, 3);
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName("/job:*/replica:4/gpu:*", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(
+ DeviceNameUtils::ParseFullName("/job:*/replica:4/device:GPU:*", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(
+ DeviceNameUtils::ParseFullName("/job:*/device:GPU/replica:4", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(
+ "/job:*/replica:4/device:myspecialdevice:13", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "myspecialdevice");
+ EXPECT_EQ(p.id, 13);
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName("/", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_FALSE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_FALSE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName("/job:*/replica:4/gpu:5", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ EXPECT_EQ(p.id, 5);
+ }
+ { // Same result if we reorder the components
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName("/gpu:*/job:*/replica:4", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ }
+
+ EXPECT_TRUE(DeviceNameUtils::IsSameAddressSpace(
+ "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:1/task:2/gpu:4"));
+ EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace(
+ "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:1/task:3/gpu:4"));
+ EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace(
+ "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:10/task:2/gpu:4"));
+ EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace(
+ "/job:foo/replica:1/task:2/cpu:3", "/job:bar/replica:1/task:2/gpu:4"));
+
+ EXPECT_EQ(DeviceNameUtils::LocalName("CPU", 1), "CPU:1");
+ EXPECT_EQ(DeviceNameUtils::LocalName("GPU", 2), "GPU:2");
+ EXPECT_EQ(DeviceNameUtils::LocalName("MySpecialDevice", 13),
+ "MySpecialDevice:13");
+
+ EXPECT_EQ(
+ DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/device:CPU:3"),
+ "CPU:3");
+
+ EXPECT_EQ(DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/cpu:3"),
+ "CPU:3");
+
+ EXPECT_EQ(
+ DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/device:abc:73"),
+ "abc:73");
+
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p));
+ EXPECT_EQ(p.type, "CPU");
+ EXPECT_EQ(p.id, 10);
+ EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p));
+ }
+}
+
+static bool IsCSHelper(StringPiece pattern, StringPiece actual) {
+ DeviceNameUtils::ParsedName p, a;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(pattern, &p));
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(actual, &a));
+ return DeviceNameUtils::IsCompleteSpecification(p, a);
+}
+
+TEST(DeviceNameUtilsTest, IsCompleteSpecification) {
+ EXPECT_TRUE(IsCSHelper("/job:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(
+ IsCSHelper("/job:*/replica:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsCSHelper("/job:*/task:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsCSHelper("/job:*/replica:*/task:*",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(
+ IsCSHelper("/job:*/replica:*/gpu:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_FALSE(IsCSHelper("/cpu:*", "/job:worker/replica:1/task:2/gpu:3"));
+ EXPECT_FALSE(IsCSHelper("/gpu:2", "/job:worker/replica:1/task:2/gpu:1"));
+ EXPECT_TRUE(IsCSHelper("/gpu:*", "/job:worker/replica:1/task:2/gpu:3"));
+}
+
+static bool IsSpecHelper(StringPiece pattern, StringPiece actual) {
+ DeviceNameUtils::ParsedName p, a;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(pattern, &p));
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(actual, &a));
+ return DeviceNameUtils::IsSpecification(p, a);
+}
+
+TEST(DeviceNameUtilsTest, IsSpecification) {
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1"));
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/replica:1"));
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work"));
+ EXPECT_TRUE(
+ IsSpecHelper("/job:*/replica:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/gpu:*",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/gpu:3",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/task:2",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:work/replica:*/task:2",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/task:*", "/job:*/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/task:2", "/job:*/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2/cpu:1"));
+ EXPECT_TRUE(IsSpecHelper("/cpu:0", "/cpu:0"));
+ EXPECT_TRUE(IsSpecHelper("/gpu:*", "/job:worker/replica:1/task:2/gpu:3"));
+
+ EXPECT_FALSE(IsSpecHelper("/job:worker/replica:1/task:2/gpu:3", "/gpu:*"));
+ EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2"));
+ EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2/gpu:1"));
+ EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:worker/replica:1/task:2/gpu:3"));
+ EXPECT_FALSE(IsSpecHelper("/gpu:2", "/job:worker/replica:1/task:2/gpu:1"));
+ EXPECT_FALSE(IsSpecHelper("/job:work/replica:*/task:0",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_FALSE(IsSpecHelper("/job:work/replica:0/task:2",
+ "/job:work/replica:*/task:2/gpu:3"));
+}
+
+TEST(DeviceNameUtilsTest, SplitDeviceName) {
+ string task;
+ string device;
+ EXPECT_TRUE(DeviceNameUtils::SplitDeviceName(
+ "/job:foo/replica:1/task:2/cpu:1", &task, &device));
+ EXPECT_EQ("/job:foo/replica:1/task:2", task);
+ EXPECT_EQ("CPU:1", device);
+ EXPECT_TRUE(DeviceNameUtils::SplitDeviceName(
+ "/job:foo/cpu:1/task:2/replica:1", &task, &device));
+ EXPECT_EQ("/job:foo/replica:1/task:2", task);
+ EXPECT_EQ("CPU:1", device);
+ EXPECT_TRUE(DeviceNameUtils::SplitDeviceName("/gpu:3", &task, &device));
+ EXPECT_EQ("", task);
+ EXPECT_EQ("GPU:3", device);
+ EXPECT_FALSE(DeviceNameUtils::SplitDeviceName("gpu:3", &task, &device));
+ EXPECT_FALSE(DeviceNameUtils::SplitDeviceName("/job:foo/task:2/replica:1",
+ &task, &device));
+ EXPECT_TRUE(DeviceNameUtils::SplitDeviceName("/device:myspecialdevice:3",
+ &task, &device));
+ EXPECT_EQ("", task);
+ EXPECT_EQ("myspecialdevice:3", device);
+}
+
+static DeviceNameUtils::ParsedName Name(const string& str) {
+ DeviceNameUtils::ParsedName ret;
+ CHECK(DeviceNameUtils::ParseFullName(str, &ret)) << "Invalid name: " << str;
+ return ret;
+}
+
+static void MergeDevNamesHelperImpl(const string& name_a, const string& name_b,
+ const string& expected_merge_name,
+ bool allow_soft_placement) {
+ DeviceNameUtils::ParsedName target_a = Name(name_a);
+ EXPECT_OK(DeviceNameUtils::MergeDevNames(&target_a, Name(name_b),
+ allow_soft_placement));
+ DeviceNameUtils::ParsedName target_b = Name(name_b);
+ EXPECT_OK(DeviceNameUtils::MergeDevNames(&target_b, Name(name_a),
+ allow_soft_placement));
+ EXPECT_EQ(target_a, target_b);
+ EXPECT_EQ(target_a, Name(expected_merge_name));
+ EXPECT_EQ(target_b, Name(expected_merge_name));
+}
+
+static void MergeDevNamesHelper(const string& name_a, const string& name_b,
+ const string& expected_merge_name) {
+ MergeDevNamesHelperImpl(name_a, name_b, expected_merge_name, false);
+}
+
+static void MergeDevNamesHelperAllowSoftPlacement(
+ const string& name_a, const string& name_b,
+ const string& expected_merge_name) {
+ MergeDevNamesHelperImpl(name_a, name_b, expected_merge_name, true);
+}
+
+static void MergeDevNamesError(const string& name_a, const string& name_b,
+ const string& expected_error_substr) {
+ DeviceNameUtils::ParsedName target_a = Name(name_a);
+ Status s = DeviceNameUtils::MergeDevNames(&target_a, Name(name_b));
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_TRUE(StringPiece(s.error_message()).contains(expected_error_substr))
+ << s;
+}
+
+TEST(DeviceNameUtilsTest, MergeDevNames) {
+ DeviceNameUtils::ParsedName target;
+
+ // Idempotence tests.
+ MergeDevNamesHelper("", "", "");
+ MergeDevNamesHelper("/job:foo/replica:1/task:2/cpu:1",
+ "/job:foo/replica:1/task:2/cpu:1",
+ "/job:foo/replica:1/task:2/cpu:1");
+
+ // Merging with empty device has no effect.
+ MergeDevNamesHelper("", "/job:foo", "/job:foo");
+ MergeDevNamesHelper("", "/replica:2", "/replica:2");
+ MergeDevNamesHelper("", "/task:7", "/task:7");
+ // MergeDevNamesHelper("", "/gpu:1", "/gpu:1");
+
+ // Combining disjoint names.
+ MergeDevNamesHelper("/job:foo", "/task:7", "/job:foo/task:7");
+ MergeDevNamesHelper("/job:foo", "/gpu:1", "/job:foo/gpu:1");
+
+ // Combining overlapping names.
+ MergeDevNamesHelper("/job:foo/replica:0", "/replica:0/task:1",
+ "/job:foo/replica:0/task:1");
+
+ // Wildcard tests.
+ MergeDevNamesHelper("", "/gpu:*", "/gpu:*");
+ MergeDevNamesHelper("/gpu:*", "/gpu:*", "/gpu:*");
+ MergeDevNamesHelper("/gpu:1", "/gpu:*", "/gpu:1");
+
+ // Incompatible components.
+ MergeDevNamesError("/job:foo", "/job:bar", "incompatible jobs");
+ MergeDevNamesError("/replica:0", "/replica:1", "incompatible replicas");
+ MergeDevNamesError("/task:0", "/task:1", "incompatible tasks");
+ MergeDevNamesError("/gpu:*", "/cpu:*", "incompatible types");
+ MergeDevNamesError("/gpu:0", "/gpu:1", "incompatible ids");
+}
+
+TEST(DeviceNameUtilsTest, MergeDevNamesAllowSoftPlacement) {
+ // Incompatible components with allow_soft_placement.
+ MergeDevNamesHelperAllowSoftPlacement("/gpu:*", "/cpu:1", "");
+ MergeDevNamesHelperAllowSoftPlacement("/cpu:*", "/gpu:1", "");
+ MergeDevNamesHelperAllowSoftPlacement("/gpu:1", "/gpu:2", "/gpu:*");
+}
+
+static void BM_ParseFullName(int iters) {
+ DeviceNameUtils::ParsedName p;
+ while (iters--) {
+ DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
+ }
+}
+BENCHMARK(BM_ParseFullName);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto
new file mode 100644
index 0000000000..5d67823ce7
--- /dev/null
+++ b/tensorflow/core/util/event.proto
@@ -0,0 +1,29 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/summary.proto";
+
+// Protocol buffer representing an event that happened during
+// the execution of a Brain model.
+message Event {
+ // Timestamp of the event.
+ double wall_time = 1;
+
+ // Global step of the event.
+ int64 step = 2;
+
+ oneof what {
+ // An event file was started, with the specified version.
+ // This is used to identify the contents of the record IO files
+ // easily. Current version is "tensorflow.Event:1". All versions
+ // start with "tensorflow.Event:".
+ string file_version = 3;
+ // A model was constructed.
+ GraphDef graph_def = 4;
+ // A summary was generated.
+ Summary summary = 5;
+ }
+}
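A minimal C++ sketch (mirroring the event construction in events_writer_test.cc) of filling in the `what` oneof with a scalar summary; setting the `summary` field implicitly selects that oneof case:

#include "tensorflow/core/util/event.pb.h"

tensorflow::Event MakeScalarEvent() {
  tensorflow::Event event;
  event.set_wall_time(1234);
  event.set_step(34);
  // Writing to mutable_summary() selects the `summary` case of `what`.
  tensorflow::Summary::Value* value = event.mutable_summary()->add_value();
  value->set_tag("loss");
  value->set_simple_value(3.14159f);
  return event;
}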
diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc
new file mode 100644
index 0000000000..1b34a36577
--- /dev/null
+++ b/tensorflow/core/util/events_writer.cc
@@ -0,0 +1,144 @@
+#include "tensorflow/core/util/events_writer.h"
+
+#include <stddef.h> // for NULL
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+
+EventsWriter::EventsWriter(const string& file_prefix)
+ // TODO(jeff,sanjay): Pass in env and use that here instead of Env::Default
+ : env_(Env::Default()),
+ file_prefix_(file_prefix),
+ num_outstanding_events_(0) {}
+
+bool EventsWriter::Init() {
+ if (recordio_writer_.get() != nullptr) {
+ CHECK(!filename_.empty());
+ if (FileHasDisappeared()) {
+ // Warn user of data loss and let .reset() below do basic cleanup.
+ if (num_outstanding_events_ > 0) {
+ LOG(WARNING) << "Re-initialization, attempting to open a new file, "
+ << num_outstanding_events_ << " events will be lost.";
+ }
+ } else {
+ // No-op: File is present and writer is initialized.
+ return true;
+ }
+ }
+
+ int64 time_in_seconds = env_->NowMicros() / 1000000;
+
+ filename_ = strings::Printf(
+ "%s.out.tfevents.%010lld.%s", file_prefix_.c_str(),
+ static_cast<long long>(time_in_seconds), port::Hostname().c_str());
+ port::AdjustFilenameForLogging(&filename_);
+
+ WritableFile* file;
+ Status s = env_->NewWritableFile(filename_, &file);
+ if (!s.ok()) {
+ LOG(ERROR) << "Could not open events file: " << filename_ << ": " << s;
+ return false;
+ }
+ recordio_file_.reset(file);
+ recordio_writer_.reset(new io::RecordWriter(recordio_file_.get()));
+ if (recordio_writer_.get() == NULL) {
+ LOG(ERROR) << "Could not create record writer";
+ return false;
+ }
+ num_outstanding_events_ = 0;
+ VLOG(1) << "Successfully opened events file: " << filename_;
+ {
+ // Write the first event with the current version, and flush
+ // right away so the file contents will be easily determined.
+
+ Event event;
+ event.set_wall_time(time_in_seconds);
+ event.set_file_version(strings::StrCat(kVersionPrefix, kCurrentVersion));
+ WriteEvent(event);
+ Flush();
+ }
+ return true;
+}
+
+string EventsWriter::FileName() {
+ if (filename_.empty()) {
+ Init();
+ }
+ return filename_;
+}
+
+void EventsWriter::WriteSerializedEvent(const string& event_str) {
+ if (recordio_writer_.get() == NULL) {
+ if (!Init()) {
+ LOG(ERROR) << "Write failed because file could not be opened.";
+ return;
+ }
+ }
+ num_outstanding_events_++;
+ recordio_writer_->WriteRecord(event_str);
+}
+
+void EventsWriter::WriteEvent(const Event& event) {
+ string record;
+ event.AppendToString(&record);
+ WriteSerializedEvent(record);
+}
+
+bool EventsWriter::Flush() {
+ if (num_outstanding_events_ == 0) return true;
+ CHECK(recordio_file_.get() != NULL) << "Unexpected NULL file";
+ // The FileHasDisappeared() condition is necessary because
+ // recordio_file_->Sync() can return OK even if the underlying
+ // file has been deleted. EventWriter.FileDeletionBeforeWriting
+ // demonstrates this and will fail if the FileHasDisappeared()
+ // condition is removed.
+ // Also, we deliberately attempt to Sync() before checking for a
+ // disappearing file, in case for some file system File::Exists() is
+ // false after File::Open() but before File::Sync().
+ if (!recordio_file_->Flush().ok() || !recordio_file_->Sync().ok() ||
+ FileHasDisappeared()) {
+ LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to "
+ << filename_;
+ return false;
+ }
+ VLOG(1) << "Wrote " << num_outstanding_events_ << " events to disk.";
+ num_outstanding_events_ = 0;
+ return true;
+}
+
+bool EventsWriter::Close() {
+ bool return_value = Flush();
+ if (recordio_file_.get() != NULL) {
+ Status s = recordio_file_->Close();
+ if (!s.ok()) {
+ LOG(ERROR) << "Error when closing previous event file: " << filename_
+ << ": " << s;
+ return_value = false;
+ }
+ recordio_writer_.reset(NULL);
+ recordio_file_.reset(NULL);
+ }
+ num_outstanding_events_ = 0;
+ return return_value;
+}
+
+bool EventsWriter::FileHasDisappeared() {
+ if (env_->FileExists(filename_)) {
+ return false;
+ } else {
+ // This can happen even with non-null recordio_writer_ if some other
+ // process has removed the file.
+ LOG(ERROR) << "The events file " << filename_ << " has disappeared.";
+ return true;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h
new file mode 100644
index 0000000000..e6b94ad265
--- /dev/null
+++ b/tensorflow/core/util/events_writer.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_UTIL_EVENTS_WRITER_H_
+#define TENSORFLOW_UTIL_EVENTS_WRITER_H_
+
+#include <memory>
+#include <string>
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+
+class EventsWriter {
+ public:
+#ifndef SWIG
+ // Prefix of version string present in the first entry of every event file.
+ static constexpr const char* kVersionPrefix = "brain.Event:";
+ static constexpr const int kCurrentVersion = 1;
+#endif
+
+ // Events files typically have a name of the form
+ //   '/some/file/path/my.file.out.tfevents.[timestamp].[hostname]'
+ // To create an EventsWriter, the user should provide file_prefix =
+ //   '/some/file/path/my.file'
+ // The EventsWriter will append '.out.tfevents.[timestamp].[hostname]'
+ // to the ultimate filename once Init() is called.
+ // Note that it is not recommended to have two EventsWriters writing
+ // to the same file_prefix simultaneously.
+ explicit EventsWriter(const string& file_prefix);
+ ~EventsWriter() { Close(); } // Autoclose in destructor.
+
+ // Sets the event file filename and opens the file for writing. If not
+ // called by the user, it will be invoked automatically by a call to
+ // FileName() or Write*().
+ // Returns false if the file could not be opened. Idempotent: if file exists
+ // and is open this is a no-op. If on the other hand the file was opened,
+ // but has since disappeared (e.g. deleted by another process), this will open
+ // a new file with a new timestamp in its filename.
+ bool Init();
+
+ // Returns the filename for the current events file:
+ // filename_ = [file_prefix_].out.tfevents.[timestamp].[hostname]
+ string FileName();
+
+ // Append "event" to the file. The "tensorflow::" part is for swig happiness.
+ void WriteEvent(const tensorflow::Event& event);
+
+ // Append "event_str", a serialized Event, to the file.
+ // Note that this function does NOT check that de-serializing event_str
+ // results in a valid Event proto.
+ void WriteSerializedEvent(const string& event_str);
+
+ // EventWriter automatically flushes and closes on destruction, but
+ // these two methods are provided for users who want to write to disk sooner
+ // and/or check for success.
+ // Flush() pushes outstanding events to disk. Returns false if the
+ // events file could not be created, or if the file exists but could
+ // not be written to.
+ // Close() calls Flush() and then closes the current events file.
+ // Returns true only if both the flush and the close succeed.
+ bool Flush();
+ bool Close();
+
+ private:
+ bool FileHasDisappeared(); // True if filename_ does not exist.
+
+ Env* env_;
+ const string file_prefix_;
+ string filename_;
+ std::unique_ptr<WritableFile> recordio_file_;
+ std::unique_ptr<io::RecordWriter> recordio_writer_;
+ int num_outstanding_events_;
+ TF_DISALLOW_COPY_AND_ASSIGN(EventsWriter);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_EVENTS_WRITER_H_
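A minimal usage sketch (the file prefix is hypothetical): Init() runs lazily on the first write, and the destructor flushes and closes, so explicit Flush()/Close() calls are only needed when the caller wants to check for success:

#include "tensorflow/core/util/event.pb.h"
#include "tensorflow/core/util/events_writer.h"

void LogEvent(const tensorflow::Event& event) {
  tensorflow::EventsWriter writer("/tmp/logs/myrun");  // hypothetical prefix
  writer.WriteEvent(event);  // Opens the events file on first write.
  if (!writer.Flush()) {
    // Handle the failed write, e.g. the file was deleted externally.
  }
}  // Destructor calls Close(), which flushes again and closes the file.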
diff --git a/tensorflow/core/util/events_writer_test.cc b/tensorflow/core/util/events_writer_test.cc
new file mode 100644
index 0000000000..f6523ead92
--- /dev/null
+++ b/tensorflow/core/util/events_writer_test.cc
@@ -0,0 +1,198 @@
+#include "tensorflow/core/util/events_writer.h"
+
+#include <math.h>
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+namespace {
+
+// shorthand
+Env* env() { return Env::Default(); }
+
+void WriteSimpleValue(EventsWriter* writer, double wall_time, int64 step,
+ const string& tag, float simple_value) {
+ Event event;
+ event.set_wall_time(wall_time);
+ event.set_step(step);
+ Summary::Value* summ_val = event.mutable_summary()->add_value();
+ summ_val->set_tag(tag);
+ summ_val->set_simple_value(simple_value);
+ writer->WriteEvent(event);
+}
+
+void WriteFile(EventsWriter* writer) {
+ WriteSimpleValue(writer, 1234, 34, "foo", 3.14159);
+ WriteSimpleValue(writer, 2345, 35, "bar", -42);
+}
+
+static bool ReadEventProto(io::RecordReader* reader, uint64* offset,
+ Event* proto) {
+ string record;
+ Status s = reader->ReadRecord(offset, &record);
+ if (!s.ok()) {
+ return false;
+ }
+ return ParseProtoUnlimited(proto, record);
+}
+
+void VerifyFile(const string& filename) {
+ CHECK(env()->FileExists(filename));
+ RandomAccessFile* event_file;
+ TF_CHECK_OK(env()->NewRandomAccessFile(filename, &event_file));
+ io::RecordReader* reader = new io::RecordReader(event_file);
+
+ uint64 offset = 0;
+
+ Event actual;
+ CHECK(ReadEventProto(reader, &offset, &actual));
+ VLOG(1) << actual.ShortDebugString();
+ // Wall time should be within 5s of now.
+
+ double current_time = env()->NowMicros() / 1000000.0;
+ EXPECT_LT(fabs(actual.wall_time() - current_time), 5);
+ // Should have the current version number.
+ EXPECT_EQ(actual.file_version(),
+ strings::StrCat(EventsWriter::kVersionPrefix,
+ EventsWriter::kCurrentVersion));
+
+ Event expected;
+ CHECK(ReadEventProto(reader, &offset, &actual));
+ VLOG(1) << actual.ShortDebugString();
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "wall_time: 1234 step: 34 "
+ "summary { value { tag: 'foo' simple_value: 3.14159 } }",
+ &expected));
+ // TODO(keveman): Enable this check
+ // EXPECT_THAT(expected, EqualsProto(actual));
+
+ CHECK(ReadEventProto(reader, &offset, &actual));
+ VLOG(1) << actual.ShortDebugString();
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "wall_time: 2345 step: 35 "
+ "summary { value { tag: 'bar' simple_value: -42 } }",
+ &expected));
+ // TODO(keveman): Enable this check
+ // EXPECT_THAT(expected, EqualsProto(actual));
+
+ TF_CHECK_OK(env()->DeleteFile(filename));
+
+ delete reader;
+ delete event_file;
+}
+
+string GetDirName(const string& suffix) {
+ return io::JoinPath(testing::TmpDir(), suffix);
+}
+
+TEST(EventWriter, WriteFlush) {
+ string file_prefix = GetDirName("/writeflush_test");
+ EventsWriter writer(file_prefix);
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Flush());
+ string filename = writer.FileName();
+ VerifyFile(filename);
+}
+
+TEST(EventWriter, WriteClose) {
+ string file_prefix = GetDirName("/writeclose_test");
+ EventsWriter writer(file_prefix);
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Close());
+ string filename = writer.FileName();
+ VerifyFile(filename);
+}
+
+TEST(EventWriter, WriteDelete) {
+ string file_prefix = GetDirName("/writedelete_test");
+ EventsWriter* writer = new EventsWriter(file_prefix);
+ WriteFile(writer);
+ string filename = writer->FileName();
+ delete writer;
+ VerifyFile(filename);
+}
+
+TEST(EventWriter, FailFlush) {
+ string file_prefix = GetDirName("/failflush_test");
+ EventsWriter writer(file_prefix);
+ string filename = writer.FileName();
+ WriteFile(&writer);
+ EXPECT_TRUE(env()->FileExists(filename));
+ env()->DeleteFile(filename);
+ EXPECT_FALSE(env()->FileExists(filename));
+ EXPECT_FALSE(writer.Flush());
+ EXPECT_FALSE(env()->FileExists(filename));
+}
+
+TEST(EventWriter, FailClose) {
+ string file_prefix = GetDirName("/failclose_test");
+ EventsWriter writer(file_prefix);
+ string filename = writer.FileName();
+ WriteFile(&writer);
+ EXPECT_TRUE(env()->FileExists(filename));
+ env()->DeleteFile(filename);
+ EXPECT_FALSE(env()->FileExists(filename));
+ EXPECT_FALSE(writer.Close());
+ EXPECT_FALSE(env()->FileExists(filename));
+}
+
+TEST(EventWriter, InitWriteClose) {
+ string file_prefix = GetDirName("/initwriteclose_test");
+ EventsWriter writer(file_prefix);
+ EXPECT_TRUE(writer.Init());
+ string filename0 = writer.FileName();
+ EXPECT_TRUE(env()->FileExists(filename0));
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Close());
+ string filename1 = writer.FileName();
+ EXPECT_EQ(filename0, filename1);
+ VerifyFile(filename1);
+}
+
+TEST(EventWriter, NameWriteClose) {
+ string file_prefix = GetDirName("/namewriteclose_test");
+ EventsWriter writer(file_prefix);
+ string filename = writer.FileName();
+ EXPECT_TRUE(env()->FileExists(filename));
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Close());
+ VerifyFile(filename);
+}
+
+TEST(EventWriter, NameClose) {
+ string file_prefix = GetDirName("/nameclose_test");
+ EventsWriter writer(file_prefix);
+ string filename = writer.FileName();
+ EXPECT_TRUE(writer.Close());
+ EXPECT_TRUE(env()->FileExists(filename));
+ env()->DeleteFile(filename);
+}
+
+TEST(EventWriter, FileDeletionBeforeWriting) {
+ string file_prefix = GetDirName("/fdbw_test");
+ EventsWriter writer(file_prefix);
+ string filename0 = writer.FileName();
+ EXPECT_TRUE(env()->FileExists(filename0));
+ env()->SleepForMicroseconds(
+ 2000000); // To make sure timestamp part of filename will differ.
+ env()->DeleteFile(filename0);
+ EXPECT_TRUE(writer.Init()); // Init should reopen file.
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Flush());
+ string filename1 = writer.FileName();
+ EXPECT_NE(filename0, filename1);
+ VerifyFile(filename1);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/guarded_philox_random.cc b/tensorflow/core/util/guarded_philox_random.cc
new file mode 100644
index 0000000000..4cf58b8979
--- /dev/null
+++ b/tensorflow/core/util/guarded_philox_random.cc
@@ -0,0 +1,39 @@
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+Status GuardedPhiloxRandom::Init(OpKernelConstruction* context) {
+ // Grab seed Attrs.
+ int64 seed, seed2;
+ auto status = context->GetAttr("seed", &seed);
+ if (!status.ok()) return status;
+ status = context->GetAttr("seed2", &seed2);
+ if (!status.ok()) return status;
+
+ // Initialize with the given seeds
+ Init(seed, seed2);
+ return Status::OK();
+}
+
+void GuardedPhiloxRandom::Init(int64 seed, int64 seed2) {
+ CHECK(!initialized_);
+ if (seed == 0 && seed2 == 0) {
+ // If both seeds are unspecified, use completely random seeds.
+ seed = random::New64();
+ seed2 = random::New64();
+ }
+ mutex_lock lock(mu_);
+ generator_ = random::PhiloxRandom(seed, seed2);
+ initialized_ = true;
+}
+
+random::PhiloxRandom GuardedPhiloxRandom::ReserveSamples128(int64 samples) {
+ CHECK(initialized_);
+ mutex_lock lock(mu_);
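+  // Copy the current generator state, then advance the shared generator past
+  // the reserved region so the returned copy can be used without the lock.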
+ auto local = generator_;
+ generator_.Skip(samples);
+ return local;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h
new file mode 100644
index 0000000000..6e9cb9f99c
--- /dev/null
+++ b/tensorflow/core/util/guarded_philox_random.h
@@ -0,0 +1,56 @@
+#ifndef TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
+#define TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// A thread-safe wrapper around a Philox generator. Example usage:
+//
+// GuardedPhiloxRandom generator;
+// generator.Init(context);
+//
+// // In thread-safe code:
+// const int samples = ...;
+// auto local_generator = generator.ReserveSamples128(samples);
+// for (int i = 0; i < samples; i++) {
+//   Array<uint32, 4> sample = local_generator();
+//   // Use sample
+// }
+//
+class GuardedPhiloxRandom {
+ public:
+ // Must call Init to finish initialization
+ GuardedPhiloxRandom() : initialized_(false) {}
+
+ // Initialize the generator from attributes "seed" and "seed2".
+ // If both seeds are unspecified, use random seeds.
+ // Must be called exactly once.
+ Status Init(OpKernelConstruction* context);
+
+ // Initialize with given seeds.
+ void Init(int64 seed, int64 seed2);
+
+ // Reserve a certain number of 128-bit samples.
+ // This function is thread safe. The returned generator is valid for the
+ // given number of samples, and can be used without a lock.
+ random::PhiloxRandom ReserveSamples128(int64 samples);
+
+ // Reserve a certain number of 32-bit samples
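+  // (each 128-bit sample yields four 32-bit values, so the count is rounded
+  // up to a multiple of four).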
+ random::PhiloxRandom ReserveSamples32(int64 samples) {
+ return ReserveSamples128((samples + 3) / 4);
+ }
+
+ private:
+ mutex mu_;
+ random::PhiloxRandom generator_ GUARDED_BY(mu_);
+ bool initialized_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GuardedPhiloxRandom);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc
new file mode 100644
index 0000000000..24273e5ca4
--- /dev/null
+++ b/tensorflow/core/util/padding.cc
@@ -0,0 +1,24 @@
+#include "tensorflow/core/util/padding.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+ Padding* value) {
+ string str_value;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value));
+ if (str_value == "SAME") {
+ *value = SAME;
+ } else if (str_value == "VALID") {
+ *value = VALID;
+ } else {
+ return errors::NotFound(str_value, " is not an allowed padding type");
+ }
+ return Status::OK();
+}
+
+string GetPaddingAttrString() { return "padding: {'SAME', 'VALID'}"; }
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h
new file mode 100644
index 0000000000..66cd96abdb
--- /dev/null
+++ b/tensorflow/core/util/padding.h
@@ -0,0 +1,37 @@
+#ifndef TENSORFLOW_UTIL_PADDING_H_
+#define TENSORFLOW_UTIL_PADDING_H_
+
+// This file contains helper routines to deal with padding in various ops and
+// kernels.
+
+#include <string>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Padding: the padding we apply to the input tensor along the rows and columns
+// dimensions. This is usually used to make sure that the spatial dimensions do
+// not shrink when we progress with convolutions. Two types of padding are
+// supported:
+// VALID: No padding is carried out.
+// SAME: The pad value is computed so that the output will have the same
+// dimensions as the input.
+// The padded area is zero-filled.
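+//
+// As a sketch of the convention for a 1-D convolution with input size n,
+// filter size k, and stride s (the names n, k, s are illustrative):
+//   VALID: output size = ceil((n - k + 1) / s); no padding is applied.
+//   SAME:  output size = ceil(n / s); the total padding applied is
+//          max((output_size - 1) * s + k - n, 0), split across both sides.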
+enum Padding {
+ VALID = 1, // No padding.
+ SAME = 2, // Input and output layers have the same size.
+};
+
+// Return the string containing the list of valid padding types that can be
+// used as an Attr() in REGISTER_OP.
+string GetPaddingAttrString();
+
+// Specialization to parse an attribute directly into a Padding enum.
+Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+ Padding* value);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_PADDING_H_
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
new file mode 100644
index 0000000000..12eb076a4d
--- /dev/null
+++ b/tensorflow/core/util/port.cc
@@ -0,0 +1,13 @@
+#include "tensorflow/core/util/port.h"
+
+namespace tensorflow {
+
+bool IsGoogleCudaEnabled() {
+#if GOOGLE_CUDA
+ return true;
+#else
+ return false;
+#endif
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h
new file mode 100644
index 0000000000..8b9d033d63
--- /dev/null
+++ b/tensorflow/core/util/port.h
@@ -0,0 +1,11 @@
+#ifndef TENSORFLOW_UTIL_PORT_H_
+#define TENSORFLOW_UTIL_PORT_H_
+
+namespace tensorflow {
+
+// Returns true if GOOGLE_CUDA is defined.
+bool IsGoogleCudaEnabled();
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_PORT_H_
diff --git a/tensorflow/core/util/saved_tensor_slice.proto b/tensorflow/core/util/saved_tensor_slice.proto
new file mode 100644
index 0000000000..f6599d9669
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice.proto
@@ -0,0 +1,76 @@
+// Protocol buffers for saved tensor slices. It's used for the brain tensor
+// ops checkpoints and the V3 checkpoints in dist_belief.
+
+// A checkpoint file is an sstable. The value for each record is a serialized
+// SavedTensorSlices message (defined below).
+//
+// Each checkpoint file has a record with the empty key (""), which corresponds
+// to a SavedTensorSlices message that contains a "meta" field serving as a
+// table of contents for all the tensor slices saved in this file. Since the
+// key is "", it is always the first record in each file.
+//
+// Each of the rest of the records in a checkpoint stores the raw data of a
+// particular tensor slice, in SavedSlice format. The corresponding key is an
+// ordered code that encodes the name of the tensor and the slice
+// information. The name is also stored in the SavedSlice message for ease of
+// debugging and manual examination.
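+//
+// For example (schematically), a checkpoint holding two slices of a tensor
+// named "w" would contain three records; the slice keys are produced by
+// EncodeTensorNameSlice() in saved_tensor_slice_util.h:
+//   ""                    -> SavedTensorSlices { meta { ... } }
+//   encoded("w", slice0)  -> SavedTensorSlices { data { ... } }
+//   encoded("w", slice1)  -> SavedTensorSlices { data { ... } }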
+
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/tensor_slice.proto";
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Metadata describing the set of slices of the same tensor saved in a
+// checkpoint file.
+message SavedSliceMeta {
+ // Name of the tensor.
+ string name = 1;
+
+  // Shape of the tensor.
+  TensorShapeProto shape = 2;
+
+  // Type of the tensor.
+  DataType type = 3;
+
+ // Explicit list of slices saved in the checkpoint file.
+ repeated TensorSliceProto slice = 4;
+};
+
+// Metadata describing the set of tensor slices saved in a checkpoint file.
+// It is always stored at the beginning of each checkpoint file.
+message SavedTensorSliceMeta {
+ // Each SavedSliceMeta describes the slices for one tensor.
+ repeated SavedSliceMeta tensor = 1;
+};
+
+// Saved tensor slice: it stores the name of the tensor, the slice, and the
+// raw data.
+message SavedSlice {
+ // Name of the tensor that this slice belongs to. This must be identical to
+ // the name used to encode the key for this record.
+ string name = 1;
+
+  // Extent of the slice. Must have one entry for each dimension of the
+ // tensor that this slice belongs to.
+ TensorSliceProto slice = 2;
+
+ // The raw data of the slice is stored as a TensorProto. Only raw data are
+ // stored (we don't fill in fields such as dtype or tensor_shape).
+ TensorProto data = 3;
+};
+
+// Each record in a v3 checkpoint file is a serialized SavedTensorSlices
+// message.
+message SavedTensorSlices {
+ // This is only present at the first item of each checkpoint file and serves
+ // as a table of contents, listing all the tensor slices saved in this file.
+ SavedTensorSliceMeta meta = 1;
+
+ // This exists in all but the first item of each checkpoint file.
+ SavedSlice data = 2;
+};
diff --git a/tensorflow/core/util/saved_tensor_slice_util.cc b/tensorflow/core/util/saved_tensor_slice_util.cc
new file mode 100644
index 0000000000..7a5903f07f
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice_util.cc
@@ -0,0 +1,76 @@
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/ordered_code.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+const char kSavedTensorSlicesKey[] = "";
+
+string EncodeTensorNameSlice(const string& name, const TensorSlice& slice) {
+ string buffer;
+ // All the tensor slice keys will start with a 0
+ tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, 0);
+ tensorflow::strings::OrderedCode::WriteString(&buffer, name);
+ tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, slice.dims());
+ for (int d = 0; d < slice.dims(); ++d) {
+    // A trivial extent (meaning we take EVERYTHING) will default to -1 for
+    // both start and length. These will be properly parsed.
+ tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer,
+ slice.start(d));
+ tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer,
+ slice.length(d));
+ }
+ return buffer;
+}
+
+Status DecodeTensorNameSlice(const string& code, string* name,
+ tensorflow::TensorSlice* slice) {
+ StringPiece src(code);
+ uint64 x;
+ if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) {
+ return errors::Internal("Failed to parse the leading number: src = ", src);
+ }
+ if (x != 0) {
+ return errors::Internal(
+ "The leading number should always be 0 for any valid key: src = ", src);
+ }
+ if (!tensorflow::strings::OrderedCode::ReadString(&src, name)) {
+ return errors::Internal("Failed to parse the tensor name: src = ", src);
+ }
+ if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) {
+ return errors::Internal("Failed to parse the tensor rank: src = ", src);
+ }
+ if (x == 0) {
+ return errors::Internal("Expecting positive rank of the tensor, got ", x,
+ ", src = ", src);
+ }
+  if (x >= kint32max) {
+    return errors::Internal("Tensor rank ", x, " is too large");
+  }
+ slice->SetFullSlice(x);
+ for (int d = 0; d < static_cast<int32>(x); ++d) {
+ // We expected 2x integers
+ int64 start, length;
+ if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src,
+ &start)) {
+ return errors::Internal("Failed to parse start: src = ", src);
+ }
+ if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src,
+ &length)) {
+ return errors::Internal("Failed to parse length: src = ", src);
+ }
+ if (length >= 0) {
+ // a non-trivial extent
+ slice->set_start(d, start);
+ slice->set_length(d, length);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
new file mode 100644
index 0000000000..6206cd8538
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -0,0 +1,110 @@
+// Utilities for saving/restoring tensor slice checkpoints.
+
+#ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+
+#include <string> // for string
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/status.h" // for Status
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+// The key for the metadata in the tensor slice checkpoint files. It is "" so
+// that the metadata is always at the beginning of a checkpoint file.
+extern const char kSavedTensorSlicesKey[];
+
+// Encodes a tensor name + a tensor slice into an ordered code and outputs it
+// as a string.
+// The format is
+// <0>
+// <tensor_name>
+// <rank>
+// <dim-0-start><dim-0-length>
+// <dim-1-start><dim-1-length>
+// ...
+
+string EncodeTensorNameSlice(const string& name,
+ const tensorflow::TensorSlice& slice);
+
+// Parses out the name and the slice from a string encoded as an ordered code.
+Status DecodeTensorNameSlice(const string& code, string* name,
+ tensorflow::TensorSlice* slice);
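+
+// A minimal round-trip sketch (the tensor name and slice spec here are
+// illustrative):
+//
+//   TensorSlice slice = TensorSlice::ParseOrDie("-:0,10");
+//   string key = EncodeTensorNameSlice("w", slice);
+//   string name;
+//   TensorSlice parsed;
+//   TF_CHECK_OK(DecodeTensorNameSlice(key, &name, &parsed));
+//   // Now name == "w" and parsed.DebugString() == "-:0,10".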
+
+template <typename T>
+struct SaveTypeTraits;
+
+template <typename T>
+const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
+ const TensorProto& t);
+
+template <typename T>
+protobuf::RepeatedField<typename SaveTypeTraits<T>::SavedType>*
+MutableTensorProtoData(TensorProto* t);
+
+template <typename T>
+void Fill(T* data, size_t n, TensorProto* t);
+
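+// The macro below stamps out the SaveTypeTraits, TensorProtoData,
+// MutableTensorProtoData, and Fill specializations for one C++ type. For
+// instance, the float instantiation reads and writes TensorProto's repeated
+// float_val field.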
+#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \
+ template <> \
+ struct SaveTypeTraits<TYPE> { \
+ static constexpr bool supported = true; \
+ typedef FTYPE SavedType; \
+ }; \
+ template <> \
+ inline const FTYPE* TensorProtoData<TYPE>(const TensorProto& t) { \
+ static_assert(SaveTypeTraits<TYPE>::supported, \
+ "Specified type " #TYPE " not supported for Restore"); \
+ return reinterpret_cast<const FTYPE*>(t.FIELD##_val().data()); \
+ } \
+ template <> \
+ inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>( \
+ TensorProto * t) { \
+ static_assert(SaveTypeTraits<TYPE>::supported, \
+ "Specified type " #TYPE " not supported for Save"); \
+ return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>( \
+ t->mutable_##FIELD##_val()); \
+ } \
+ template <> \
+ inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \
+ typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \
+ t->mutable_##FIELD##_val()->Swap(&copy); \
+ }
+
+TENSOR_PROTO_EXTRACT_TYPE(float, float, float);
+TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
+TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int64, int64, int64);
+TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
+
+#undef TENSOR_PROTO_EXTRACT_TYPE
+
+template <>
+struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
+
+template <>
+inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
+ static_assert(SaveTypeTraits<qint32>::supported,
+ "Specified type qint32 not supported for Restore");
+ return reinterpret_cast<const int32*>(t.int_val().data());
+}
+
+inline void Fill(const qint32* data, size_t n, TensorProto* t) {
+ const int32* p = reinterpret_cast<const int32*>(data);
+ typename protobuf::RepeatedField<int32> copy(p, p + n);
+ t->mutable_int_val()->Swap(&copy);
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
diff --git a/tensorflow/core/util/saved_tensor_slice_util_test.cc b/tensorflow/core/util/saved_tensor_slice_util_test.cc
new file mode 100644
index 0000000000..2c34c903db
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice_util_test.cc
@@ -0,0 +1,32 @@
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+// Testing serialization of tensor name and tensor slice in the ordered code
+// format.
+TEST(TensorShapeUtilTest, TensorNameSliceToOrderedCode) {
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5");
+ string buffer = EncodeTensorNameSlice("foo", s);
+ string name;
+ s.Clear();
+ TF_CHECK_OK(DecodeTensorNameSlice(buffer, &name, &s));
+ EXPECT_EQ("foo", name);
+ EXPECT_EQ("-:-:1,3:4,5", s.DebugString());
+ }
+}
+
+} // namespace
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/sparse/README.md b/tensorflow/core/util/sparse/README.md
new file mode 100644
index 0000000000..7b0799eb0e
--- /dev/null
+++ b/tensorflow/core/util/sparse/README.md
@@ -0,0 +1,222 @@
+SparseTensor
+============
+
+A SparseTensor is stored as two dense tensors, a shape, and an optional order:
+
+* `indices`: a `brain::Tensor` storing a matrix of indices, typically `int64`.
+* `values`: a `brain::Tensor` storing a vector with values of type T.
+* `shape`: a `TensorShape` storing the bounds of the underlying tensor.
+* `order`: (optional) a `gtl::InlinedVector<int64,8>` with the dimensions
+  along which the indices are ordered.
+
+Let
+
+ ix = indices.matrix<int64>()
+ vals = values.vec<T>()
+
+The shape of `ix` is `N x NDIMS`, and each row corresponds to the
+index of a single element of the sparse tensor.
+
+The length of `vals` must be `N`, and `vals(i)` corresponds to the
+value with index `ix(i,:)`.
+
+Shape must be a `TensorShape` with `dims() == NDIMS`.
+The shape is the full shape of the dense tensor these indices
+represent.
+
+To be specific, the representation (pseudocode) is:
+
+ tensor[ix[i,:]] == vals[i] for i = 0, ..., N-1
+
+Ordering
+--------
+
+Indices need not be provided in order. For example, the following
+index matrix is ordered according to dimension order `{0, 1, 2}`.
+
+ [0 0 1]
+ [0 1 1]
+ [2 0 2]
+
+However, you can provide an unordered version:
+
+ [2 0 2]
+ [0 0 1]
+ [0 1 1]
+
+If the SparseTensor is constructed without a provided order, then the
+default order is `{-1, ..., -1}`. Certain operations will fail or crash
+when the order is not provided.
+
+Resorting the SparseTensor in-place (which resorts the underlying index and
+values tensors in-place) will update the order. The cost of reordering the
+matrix is `O(N*log(N))`, and requires `O(N)` additional temporary space to store
+a reordering index. If the order is left at its default and no reordering is
+performed, the following will happen:
+
+* `group()` will **raise an assertion failure**
+* `IndicesValid()` will **raise an assertion failure**
+
+To update the internal index ordering after construction, call
+`Reorder<T>()` via, e.g., `Reorder<T>({0,1,2})`.
+After this step, all the above methods should work correctly.
+
+The method `IndicesValid()` checks to make sure:
+
+* `0 <= ix(i, d) < shape.dim_size(d)`
+* indices do not repeat
+* indices are in order
+
+Iterating
+---------
+
+### group({grouping dims})
+
+* provides an iterator that groups entries according to
+ dimensions you care about
+* may require a sort if your data isn't presorted in a way that's
+ compatible with grouping_dims
+* for each group, returns the group index (values of the group
+ dims for this iteration), the subset of indices in this group,
+  and the subset of values in this group. These outputs are lazy,
+  so to read them individually, copy them as in the example
+  below.
+
+#### **NOTE**
+`group({dim0, ..., dimk})` will **raise an assertion failure** if the
+order of the SparseTensor does not match the dimensions you wish to group by.
+You must either have your indices in the correct order and construct the
+SparseTensor with
+
+ order = {dim0, ..., dimk, ...}
+
+or call
+
+    Reorder<T>({dim0, ..., dimk, ...})
+
+to sort the SparseTensor before grouping.
+
+Example of grouping:
+
+    Tensor indices(DT_INT64, TensorShape({N, NDIMS}));
+    Tensor values(DT_STRING, TensorShape({N}));
+    TensorShape shape({dim0,...});
+    SparseTensor sp(indices, values, shape);
+    sp.Reorder<string>({1, 2, 0, 3, ...}); // Must provide NDIMS dims.
+ // group according to dims 1 and 2
+ for (const auto& g : sp.group({1, 2})) {
+ cout << "vals of ix[:, 1,2] for this group: "
+ << g.group()[0] << ", " << g.group()[1];
+ cout << "full indices of group:\n" << g.indices();
+ cout << "values of group:\n" << g.values();
+
+ TTypes<int64>::UnalignedMatrix g_ix = g.indices();
+ TTypes<string>::UnalignedVec g_v = g.values();
+ ASSERT(g_ix.dimension(0) == g_v.size()); // number of elements match.
+ }
+
+
+ToDense
+--------
+
+Converts sparse tensor to dense. You must provide a pointer to the
+dense tensor (preallocated). `ToDense()` will optionally
+preinitialize the tensor with zeros.
+
+Shape checking is performed, as is boundary checking.
+
+    Tensor indices(DT_INT64, TensorShape({N, NDIMS}));
+    Tensor values(DT_STRING, TensorShape({N}));
+    TensorShape shape({dim0,...});
+    SparseTensor sp(indices, values, shape);
+    ASSERT(sp.IndicesValid()); // checks ordering & index bounds.
+
+ Tensor dense(DT_STRING, shape);
+    // Initialize all other entries to the default value, then copy.
+ ASSERT(sp.ToDense<string>(&dense, true));
+
+
+Concat
+--------
+
+Concatenates multiple SparseTensors and returns a new SparseTensor.
+This concatenation is with respect to the "dense" versions of these
+SparseTensors. Concatenation is performed along dimension order[0]
+of all tensors. As a result, shape[order[0]] may differ across
+the inputs, but shape[d] for d != order[0] must match across all inputs.
+
+We call order[0] the **primary dimension**.
+
+**Prerequisites**
+
+* The inputs' ranks must all match.
+* The inputs' order[0] must all match.
+* The inputs' shapes must all match except for dimension order[0].
+* The inputs' values must all be of the same type.
+
+If any of these are false, concat will die with an assertion failure.
+
+Example:
+Concatenate two sparse matrices along columns.
+
+Matrix 1:
+
+ [0 0 1]
+ [2 0 0]
+ [3 0 4]
+
+Matrix 2:
+
+ [0 0 0 0 0]
+ [0 1 0 0 0]
+ [2 0 0 1 0]
+
+Concatenated Matrix:
+
+ [0 0 1 0 0 0 0 0]
+ [2 0 0 0 1 0 0 0]
+ [3 0 4 2 0 0 1 0]
+
+Expected input shapes, orders, and nonzero counts (`num_entries()`):
+
+ shape_1 = TensorShape({3, 3})
+    shape_2 = TensorShape({3, 5})
+ order_1 = {1, 0} // primary order is 1, columns
+ order_2 = {1, 0} // primary order is 1, must match
+ nnz_1 = 4
+ nnz_2 = 3
+
+Output shapes and orders:
+
+    conc_shape = TensorShape({3, 8})  // primary dim increased, others same
+ conc_order = {1, 0} // Orders match along all inputs
+ conc_nnz = 7 // Sum of nonzeros of inputs
+
+Coding Example:
+
+    Tensor ix1(DT_INT64, TensorShape({N1, 3}));
+    Tensor vals1(DT_STRING, TensorShape({N1}));
+    Tensor ix2(DT_INT64, TensorShape({N2, 3}));
+    Tensor vals2(DT_STRING, TensorShape({N2}));
+    Tensor ix3(DT_INT64, TensorShape({N3, 3}));
+    Tensor vals3(DT_STRING, TensorShape({N3}));
+
+ SparseTensor st1(ix1, vals1, TensorShape({10, 20, 5}), {1, 0, 2});
+ SparseTensor st2(ix2, vals2, TensorShape({10, 10, 5}), {1, 0, 2});
+ // For kicks, st3 indices are out of order, but order[0] matches so we
+ // can still concatenate along this dimension.
+ SparseTensor st3(ix3, vals3, TensorShape({10, 30, 5}), {1, 2, 0});
+
+ SparseTensor conc = SparseTensor::Concat<string>({st1, st2, st3});
+ Tensor ix_conc = conc.indices();
+ Tensor vals_conc = conc.values();
+    EXPECT_EQ(conc.num_entries(),
+              st1.num_entries() + st2.num_entries() + st3.num_entries());
+    EXPECT_EQ(conc.shape(), TensorShape({10, 60, 5}));
+    EXPECT_EQ(conc.order(), {-1, -1, -1});
+
+ // Reorder st3 so all input tensors have the exact same orders.
+ st3.Reorder<string>({1, 0, 2});
+ SparseTensor conc2 = SparseTensor::Concat<string>({st1, st2, st3});
+    EXPECT_EQ(conc2.order(), {1, 0, 2});
+ // All indices' orders matched, so output is in order.
+ EXPECT_TRUE(conc2.IndicesValid());
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h
new file mode 100644
index 0000000000..57473867cf
--- /dev/null
+++ b/tensorflow/core/util/sparse/dim_comparator.h
@@ -0,0 +1,60 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+/////////////////
+// DimComparator
+/////////////////
+//
+// Helper class, mainly used by the IndexSortOrder. This comparator
+// can be passed to e.g. std::sort, or any other sorter, to sort two
+// rows of an index matrix according to the dimension(s) of interest.
+// The dimensions to sort by are passed to the constructor as "order".
+//
+// Example: if given index matrix IX, two rows ai and bi, and order = {2,1}.
+// operator() compares
+// IX(ai,2) < IX(bi,2).
+// If IX(ai,2) == IX(bi,2), it compares
+// IX(ai,1) < IX(bi,1).
+//
+// This can be used to sort a vector of row indices into IX according to
+// the values in IX in particular columns (dimensions) of interest.
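+//
+// A small usage sketch, mirroring how Reorder() in sparse_tensor.h uses it:
+//
+//   std::vector<int64> rows(ix.dimension(0));
+//   std::iota(rows.begin(), rows.end(), 0);
+//   std::sort(rows.begin(), rows.end(), DimComparator(ix, order, dims));
+//   // rows now holds the row numbers of ix in sorted order.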
+class DimComparator {
+ public:
+ typedef typename gtl::ArraySlice<int64> VarDimArray;
+
+ inline DimComparator(const TTypes<int64>::Matrix& ix,
+ const VarDimArray& order, int dims)
+ : ix_(ix), order_(order), dims_(dims) {
+ CHECK_GT(order.size(), 0) << "Must order using at least one index";
+ CHECK_LE(order.size(), dims_) << "Can only sort up to dims";
+ for (size_t d = 0; d < order.size(); ++d) {
+ CHECK_GE(order[d], 0);
+ CHECK_LT(order[d], dims);
+ }
+ }
+
+ inline bool operator()(const int64 i, const int64 j) const {
+    // Compare only along the dimensions listed in order_ (which may be
+    // fewer than dims_).
+    for (std::size_t di = 0; di < order_.size(); ++di) {
+ const int64 d = order_[di];
+ if (ix_(i, d) < ix_(j, d)) return true;
+ if (ix_(i, d) > ix_(j, d)) return false;
+ }
+ return false;
+ }
+
+ const TTypes<int64>::Matrix ix_;
+ const VarDimArray order_;
+ const int dims_;
+};
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
new file mode 100644
index 0000000000..e153bcdbb4
--- /dev/null
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -0,0 +1,49 @@
+#include "tensorflow/core/util/sparse/group_iterator.h"
+
+namespace tensorflow {
+namespace sparse {
+
+void GroupIterable::IteratorStep::UpdateEndOfGroup() {
+ ++next_loc_;
+ int64 N = iter_->ix_.dim_size(0);
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
+ ++next_loc_;
+ }
+}
+
+bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const {
+ CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators";
+ return (rhs.loc_ != loc_);
+}
+
+GroupIterable::IteratorStep& GroupIterable::IteratorStep::
+operator++() { // prefix ++
+ loc_ = next_loc_;
+ UpdateEndOfGroup();
+ return *this;
+}
+
+GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
+ int) { // postfix ++
+ IteratorStep lhs(*this);
+ ++(*this);
+ return lhs;
+}
+
+std::vector<int64> Group::group() const {
+ std::vector<int64> g;
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ for (const int d : iter_->group_dims_) {
+ g.push_back(ix_t(loc_, d));
+ }
+ return g;
+}
+
+TTypes<int64>::UnalignedConstMatrix Group::indices() const {
+ return TTypes<int64>::UnalignedConstMatrix(
+ &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+}
+
+} // namespace sparse
+} // namespace tensorflow
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
new file mode 100644
index 0000000000..8423d54f27
--- /dev/null
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -0,0 +1,120 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+class GroupIterable; // Predeclare GroupIterable for Group.
+
+// This class is returned when dereferencing a GroupIterable iterator.
+// It provides the methods group(), indices(), and values(), which
+// provide access into the underlying SparseTensor.
+class Group {
+ public:
+ Group(GroupIterable* iter, int64 loc, int64 next_loc)
+ : iter_(iter), loc_(loc), next_loc_(next_loc) {}
+
+ std::vector<int64> group() const;
+ TTypes<int64>::UnalignedConstMatrix indices() const;
+ template <typename T>
+ typename TTypes<T>::UnalignedVec values() const;
+
+ private:
+ GroupIterable* iter_;
+ int64 loc_;
+ int64 next_loc_;
+};
+
+/////////////////
+// GroupIterable
+/////////////////
+//
+// Returned when calling sparse_tensor.group({dim0, dim1, ...}).
+//
+// Please note: the sparse_tensor should already be ordered according
+// to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups.
+//
+// Allows grouping and iteration of the SparseTensor according to the
+// subset of dimensions provided to the group call.
+//
+// The actual grouping dimensions are stored in the
+// internal vector group_dims_. Iterators inside the iterable provide
+// the three methods:
+//
+// * group(): returns a vector with the current group dimension values.
+// * indices(): a matrix map providing the indices in
+//   this group.
+// * values(): a vector map providing the values in
+//   this group.
+//
+// To iterate across GroupIterable, see examples in README.md.
+
+class GroupIterable {
+ public:
+ typedef gtl::ArraySlice<int64> VarDimArray;
+
+ GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
+ : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {}
+
+ class IteratorStep;
+
+ IteratorStep begin() { return IteratorStep(this, 0); }
+ IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
+
+ template <typename TIX>
+ inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
+ bool matches = true;
+ for (int d : group_dims_) {
+ if (ix(loc_a, d) != ix(loc_b, d)) {
+ matches = false;
+ }
+ }
+ return matches;
+ }
+
+ class IteratorStep {
+ public:
+ IteratorStep(GroupIterable* iter, int64 loc)
+ : iter_(iter), loc_(loc), next_loc_(loc_) {
+ UpdateEndOfGroup();
+ }
+
+ void UpdateEndOfGroup();
+ bool operator!=(const IteratorStep& rhs) const;
+ IteratorStep& operator++(); // prefix ++
+ IteratorStep operator++(int); // postfix ++
+ Group operator*() const { return Group(iter_, loc_, next_loc_); }
+
+ private:
+ GroupIterable* iter_;
+ int64 loc_;
+ int64 next_loc_;
+ };
+
+ private:
+ friend class Group;
+ Tensor ix_;
+ Tensor vals_;
+ const int dims_;
+ const VarDimArray group_dims_;
+};
+
+// Implementation of Group::values<T>()
+template <typename T>
+typename TTypes<T>::UnalignedVec Group::values() const {
+ return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
+ next_loc_ - loc_);
+}
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
new file mode 100644
index 0000000000..dcb75e7f54
--- /dev/null
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -0,0 +1,353 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+
+#include <limits>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/sparse/dim_comparator.h"
+#include "tensorflow/core/util/sparse/group_iterator.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+class SparseTensor {
+ public:
+ typedef typename gtl::ArraySlice<int64> VarDimArray;
+
+ SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
+ : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
+
+ SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
+ const VarDimArray& order)
+ : ix_(ix),
+ vals_(vals),
+ shape_(shape),
+ order_(order.begin(), order.end()),
+ dims_(GetDimsFromIx(ix)) {
+ CHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: "
+ << ix.dtype();
+ CHECK(TensorShapeUtils::IsMatrix(ix.shape()))
+ << "indices must be a matrix, but got: " << ix.shape().DebugString();
+ CHECK(TensorShapeUtils::IsVector(vals.shape()))
+ << "vals must be a vec, but got: " << vals.shape().DebugString();
+ CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
+ << "indices and values rows (indexing dimension) must match.";
+ }
+
+ std::size_t num_entries() const { return ix_.dim_size(0); }
+
+ const Tensor& indices() const { return ix_; }
+
+ const Tensor& values() const { return vals_; }
+
+ DataType dtype() const { return vals_.dtype(); }
+
+ bool IndicesValid() const {
+ const auto ix_t = ix_.matrix<int64>();
+ for (int64 ord : order_) {
+      CHECK_GE(ord, 0) << "Order was not provided. Provide an order at "
+                          "construction time or call Reorder()";
+ }
+
+ for (std::size_t n = 0; n < num_entries(); ++n) {
+ if (!IndexValid(ix_t, n)) return false;
+ }
+
+ return true;
+ }
+
+ // Returns the tensor shape (the dimensions of the "densified"
+ // tensor this tensor represents).
+ const TensorShape shape() const { return shape_; }
+
+ const VarDimArray order() const { return order_; }
+
+ // Resorts the indices and values according to the dimensions in order.
+ template <typename T>
+ void Reorder(const VarDimArray& order);
+
+ // Returns a group iterable that can be used for clumping indices
+ // and values according to the group indices of interest.
+ //
+ // Precondition: order()[0..group_ix.size()] == group_ix.
+ //
+ // See the README.md in this directory for more usage information.
+ GroupIterable group(const VarDimArray& group_ix) {
+ CHECK_LE(group_ix.size(), dims_);
+ for (std::size_t di = 0; di < group_ix.size(); ++di) {
+ CHECK_GE(group_ix[di], 0) << "Group dimension out of range";
+ CHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
+ CHECK_EQ(group_ix[di], order_[di])
+ << "Group dimension does not match sorted order";
+ }
+ return GroupIterable(ix_, vals_, dims_, group_ix);
+ }
+
+ // Stores the sparse indices into the dense tensor out.
+ // Preconditions:
+ // out->shape().dims() == shape().dims()
+ // out->shape().dim_size(d) >= shape(d) for all d
+ //
+ // Returns true on success. False on failure (mismatched dimensions
+ // or out-of-bounds indices).
+ //
+  // If initialize is true, ToDense first overwrites all coefficients in out
+  // with 0.
+ //
+ template <typename T>
+ bool ToDense(Tensor* out, bool initialize = true);
+
+ // Concat() will concatenate all the tensors according to their first order
+ // dimension. All tensors must have identical shape except for
+  // the first order dimension. All tensors' order[0] (the first ordered
+  // dimension) must match.
+ //
+ // If all of the tensors have identical ordering, then the output
+ // will have this ordering. Otherwise the output is set as not
+ // having any order and a Reorder<T>() should be called on it before
+ // performing any subsequent operations.
+ template <typename T>
+ static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);
+
+ private:
+ static int GetDimsFromIx(const Tensor& ix) {
+ CHECK(TensorShapeUtils::IsMatrix(ix.shape()));
+ return ix.dim_size(1);
+ }
+
+ static gtl::InlinedVector<int64, 8> UndefinedOrder(const TensorShape& shape) {
+ return gtl::InlinedVector<int64, 8>(shape.dims(), -1);
+ }
+
+ // Helper for IndicesValid()
+ inline bool IndexValid(const TTypes<int64>::ConstMatrix& ix_t,
+ int64 n) const {
+ bool different = false;
+ bool bad_order = false;
+ bool valid = true;
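+    // "valid" tracks whether every coordinate is within bounds; "different"
+    // tracks whether this row differs from the previous one (catching
+    // duplicates); "bad_order" is set when the first differing dimension,
+    // scanned following order_, decreases.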
+ if (n == 0) {
+ for (int di = 0; di < dims_; ++di) {
+ if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di))
+ valid = false;
+ }
+ different = true;
+ } else {
+ for (int di = 0; di < dims_; ++di) {
+ if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di))
+ valid = false;
+ int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
+ if (diff > 0) different = true;
+ if (!different && diff < 0) bad_order = true;
+ }
+ }
+ if (!valid) return false; // Out of bounds
+    if (!different) return false;  // The last two indices are identical...
+ if (bad_order) return false; // Decreasing in order.
+ return true;
+ }
+
+ // Helper for ToDense<T>()
+ template <typename T>
+ bool ValidateAndInitializeToDense(Tensor* out, bool initialize);
+
+ Tensor ix_;
+ Tensor vals_;
+ TensorShape shape_;
+ gtl::InlinedVector<int64, 8> order_;
+ const int dims_;
+};
+
+// This operation updates the indices and values Tensor rows, so it is
+// an in-place algorithm. It requires O(N log N) time and O(N)
+// temporary space.
+template <typename T>
+void SparseTensor::Reorder(const VarDimArray& order) {
+ CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ << "Reorder requested with the wrong datatype";
+ CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
+ auto ix_t = ix_.matrix<int64>();
+ auto vals_t = vals_.vec<T>();
+
+ DimComparator sorter(ix_t, order, dims_);
+
+ std::vector<int64> reorder(num_entries());
+ std::iota(reorder.begin(), reorder.end(), 0);
+
+ // Sort to get order of indices
+ std::sort(reorder.begin(), reorder.end(), sorter);
+
+  // We have a forward reordering, but what we'll need is a
+  // permutation (the inverse). This can be calculated with O(1)
+  // additional space and O(N) time in place (INVPERM), but we just do
+  // the simple thing here.
+ std::vector<int64> permutation(reorder.size());
+ for (std::size_t n = 0; n < reorder.size(); ++n) {
+ permutation[reorder[n]] = n;
+ }
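+  // For example, if reorder = {2, 0, 1} (row 2 sorts first), then
+  // permutation = {1, 2, 0}: row 0 moves to slot 1, row 1 to slot 2, and
+  // row 2 to slot 0.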
+
+ // Update indices & values by converting the permutations to
+ // a product of transpositions. Iterate over the cycles in the
+ // permutation, and convert each of those into a product of
+ // transpositions (swaps):
+ // https://en.wikipedia.org/wiki/Cyclic_permutation
+ // This is N swaps, 2*N comparisons.
+ for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
+ while (n != permutation[n]) {
+ std::size_t r = permutation[n];
+ std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
+ std::swap(vals_t(n), vals_t(r));
+ std::swap(permutation[n], permutation[r]);
+ }
+ }
+
+ order_ = gtl::InlinedVector<int64, 8>(order.begin(), order.end());
+}
+
+template <typename T>
+bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
+ CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ << "ToDense requested with the wrong datatype";
+
+ CHECK_EQ(out->shape().dims(), dims_)
+ << "Incompatible dimensions between SparseTensor and output";
+
+ CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
+ << "Output must be type: " << DataTypeToEnum<T>::v()
+ << " but got: " << out->dtype();
+
+ // Make sure the dense output is the same rank and has room
+ // to hold the SparseTensor.
+ const auto& out_shape = out->shape();
+ if (shape_.dims() != out_shape.dims()) return false;
+ for (int d = 0; d < shape_.dims(); ++d) {
+ if (shape_.dim_size(d) > out_shape.dim_size(d)) return false;
+ }
+
+ if (initialize) {
+ auto out_t = out->flat<T>();
+ out_t.setConstant(T());
+ }
+
+ return true;
+}
+
+template <typename T>
+bool SparseTensor::ToDense(Tensor* out, bool initialize) {
+ if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;
+
+ auto out_t = out->flat<T>();
+ auto ix_t = ix_.matrix<int64>();
+ auto vals_t = vals_.vec<T>();
+
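+  // Precompute the row-major strides of the dense output so each sparse
+  // index row can be flattened into a single linear offset below.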
+ std::vector<int64> strides(dims_);
+ const auto& out_shape = out->shape();
+ strides[dims_ - 1] = 1;
+ for (int d = dims_ - 2; d >= 0; --d) {
+ strides[d] = strides[d + 1] * out_shape.dim_size(d + 1);
+ }
+
+ for (std::size_t n = 0; n < vals_t.dimension(0); ++n) {
+ bool invalid_dims = false;
+ int64 ix = 0;
+ for (int d = 0; d < dims_; ++d) {
+ const int64 ix_n_d = ix_t(n, d);
+ if (ix_n_d < 0 || ix_n_d >= out_shape.dim_size(d)) {
+ invalid_dims = true;
+ }
+ ix += strides[d] * ix_n_d;
+ }
+ if (invalid_dims) return false;
+ out_t(ix) = vals_t(n);
+ }
+ return true;
+}
+
+template <typename T>
+SparseTensor SparseTensor::Concat(
+ const gtl::ArraySlice<SparseTensor>& tensors) {
+ CHECK_GE(tensors.size(), 1) << "Cannot concat 0 SparseTensors";
+ const int dims = tensors[0].dims_;
+ CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
+ auto order_0 = tensors[0].order();
+ const int primary_dim = order_0[0];
+ gtl::InlinedVector<int64, 8> final_order(order_0.begin(), order_0.end());
+ TensorShape final_shape(tensors[0].shape());
+ final_shape.set_dim(primary_dim, 0); // We'll build this up as we go along.
+ int num_entries = 0;
+
+ bool fully_ordered = true;
+ for (const SparseTensor& st : tensors) {
+ CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
+ CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
+ << "Concat requested with the wrong data type";
+ CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
+ CHECK_EQ(st.order()[0], primary_dim)
+ << "All SparseTensors' order[0] must match. This is the concat dim.";
+ if (st.order() != final_order) fully_ordered = false;
+ const TensorShape st_shape = st.shape();
+ for (int d = 0; d < dims - 1; ++d) {
+ const int cdim = (d < primary_dim) ? d : d + 1;
+ CHECK_EQ(final_shape.dim_size(cdim), st_shape.dim_size(cdim))
+ << "All SparseTensors' shapes must match except on the concat dim. "
+ << "Concat dim: " << primary_dim
+ << ", mismatched shape at dim: " << cdim
+ << ". Expecting shape like: " << final_shape.DebugString()
+ << " but saw shape: " << st_shape.DebugString();
+ }
+
+ // Update dimension of final shape
+ final_shape.set_dim(primary_dim, final_shape.dim_size(primary_dim) +
+ st_shape.dim_size(primary_dim));
+
+ num_entries += st.num_entries(); // Update number of entries
+ }
+
+  // If the orderings are inconsistent among inputs, set final order to -1s.
+ if (!fully_ordered) {
+ final_order = UndefinedOrder(final_shape);
+ }
+
+ Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
+ Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));
+
+ auto ix_t = output_ix.matrix<int64>();
+ auto vals_t = output_vals.vec<T>();
+
+ Eigen::DenseIndex offset = 0;
+ int64 shape_offset = 0;
+ for (const SparseTensor& st : tensors) {
+ int st_num_entries = st.num_entries();
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_start(offset, 0);
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_size(st_num_entries, dims);
+ Eigen::DSizes<Eigen::DenseIndex, 1> vals_start(offset);
+ Eigen::DSizes<Eigen::DenseIndex, 1> vals_size(st_num_entries);
+
+ // Fill in indices & values.
+ ix_t.slice(ix_start, ix_size) = st.ix_.matrix<int64>();
+ vals_t.slice(vals_start, vals_size) = st.vals_.vec<T>();
+
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_update_start(offset, primary_dim);
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_update_size(st_num_entries, 1);
+ // The index associated with the primary dimension gets increased
+ // by the shapes of the previous concatted Tensors.
+ auto update_slice = ix_t.slice(ix_update_start, ix_update_size);
+ update_slice += update_slice.constant(shape_offset);
+
+ offset += st_num_entries;
+ shape_offset += st.shape().dim_size(primary_dim);
+ }
+
+ return SparseTensor(output_ix, output_vals, final_shape, final_order);
+}
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc
new file mode 100644
index 0000000000..47126b7187
--- /dev/null
+++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc
@@ -0,0 +1,467 @@
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+namespace {
+
+Eigen::Tensor<int64, 2, Eigen::RowMajor, Eigen::DenseIndex>
+GetSimpleIndexTensor(int N, const int NDIM) {
+ Eigen::Tensor<int64, 2, Eigen::RowMajor, Eigen::DenseIndex> ix(N, NDIM);
+ ix(0, 0) = 0;
+ ix(0, 1) = 0;
+ ix(0, 2) = 0;
+
+ ix(1, 0) = 3;
+ ix(1, 1) = 0;
+ ix(1, 2) = 0;
+
+ ix(2, 0) = 2;
+ ix(2, 1) = 0;
+ ix(2, 2) = 0;
+
+ ix(3, 0) = 0;
+ ix(3, 1) = 1;
+ ix(3, 2) = 0;
+
+ ix(4, 0) = 0;
+ ix(4, 1) = 0;
+ ix(4, 2) = 2;
+ return ix;
+}
+
+TEST(SparseTensorTest, DimComparatorSorts) {
+ std::size_t N = 5;
+ const int NDIM = 3;
+ auto ix = GetSimpleIndexTensor(N, NDIM);
+ TTypes<int64>::Matrix map(ix.data(), N, NDIM);
+
+ std::vector<int64> sorting(N);
+ for (std::size_t n = 0; n < N; ++n) sorting[n] = n;
+
+ // new order should be: {0, 4, 3, 2, 1}
+ std::vector<int64> order{0, 1, 2};
+ DimComparator sorter(map, order, NDIM);
+ std::sort(sorting.begin(), sorting.end(), sorter);
+
+ EXPECT_EQ(sorting, std::vector<int64>({0, 4, 3, 2, 1}));
+
+ // new order should be: {0, 3, 2, 1, 4}
+ std::vector<int64> order1{2, 0, 1};
+ DimComparator sorter1(map, order1, NDIM);
+ for (std::size_t n = 0; n < N; ++n) sorting[n] = n;
+ std::sort(sorting.begin(), sorting.end(), sorter1);
+
+ EXPECT_EQ(sorting, std::vector<int64>({0, 3, 2, 1, 4}));
+}
+
+TEST(SparseTensorTest, SparseTensorConstruction) {
+ int N = 5;
+ const int NDIM = 3;
+ auto ix_c = GetSimpleIndexTensor(N, NDIM);
+ Eigen::Tensor<string, 1, Eigen::RowMajor> vals_c(N);
+ vals_c(0) = "hi0";
+ vals_c(1) = "hi1";
+ vals_c(2) = "hi2";
+ vals_c(3) = "hi3";
+ vals_c(4) = "hi4";
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<string>();
+ vals_t = vals_c;
+ ix_t = ix_c;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid()); // Out of order
+
+  // Regardless of how the order is updated, so long as there are no
+  // duplicates the resulting indices are valid.
+ st.Reorder<string>({2, 0, 1});
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(vals_t(0), "hi0");
+ EXPECT_EQ(vals_t(1), "hi3");
+ EXPECT_EQ(vals_t(2), "hi2");
+ EXPECT_EQ(vals_t(3), "hi1");
+ EXPECT_EQ(vals_t(4), "hi4");
+
+ ix_t = ix_c;
+ vals_t = vals_c;
+ st.Reorder<string>({0, 1, 2});
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(vals_t(0), "hi0");
+ EXPECT_EQ(vals_t(1), "hi4");
+ EXPECT_EQ(vals_t(2), "hi3");
+ EXPECT_EQ(vals_t(3), "hi2");
+ EXPECT_EQ(vals_t(4), "hi1");
+
+ ix_t = ix_c;
+ vals_t = vals_c;
+ st.Reorder<string>({2, 1, 0});
+ EXPECT_TRUE(st.IndicesValid());
+}
+
+TEST(SparseTensorTest, EmptySparseTensorAllowed) {
+ int N = 0;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(st.order(), order);
+
+ std::vector<int64> new_order{1, 0, 2};
+ st.Reorder<string>(new_order);
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(st.order(), new_order);
+}
+
+TEST(SparseTensorTest, SortingWorksCorrectly) {
+ int N = 30;
+ const int NDIM = 4;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+ TensorShape shape({1000, 1000, 1000, 1000});
+ SparseTensor st(ix, vals, shape);
+
+ auto ix_t = ix.matrix<int64>();
+
+ for (int n = 0; n < 100; ++n) {
+ ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator<int64>(n + 1));
+ ix_t = ix_t.abs() % 1000;
+ st.Reorder<string>({0, 1, 2, 3});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({3, 2, 1, 0});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({1, 0, 2, 3});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({3, 0, 2, 1});
+ EXPECT_TRUE(st.IndicesValid());
+ }
+}
+
+TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
+ int N = 2;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ Eigen::Tensor<int64, 2, Eigen::RowMajor> ix_orig(N, NDIM);
+ ix_orig(0, 0) = 0;
+ ix_orig(0, 1) = 0;
+ ix_orig(0, 2) = 0;
+
+ ix_orig(1, 0) = 0;
+ ix_orig(1, 1) = 0;
+ ix_orig(1, 2) = 0;
+
+ auto ix_t = ix.matrix<int64>();
+ ix_t = ix_orig;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid()); // two indices are identical
+
+ ix_orig(1, 2) = 1;
+ ix_t = ix_orig;
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid()); // second index now (0, 0, 1)
+
+ ix_orig(0, 2) = 1;
+ ix_t = ix_orig;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid()); // first index now (0, 0, 1)
+}
+
+TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+
+ ix.matrix<int64>() = ix_t;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+
+ ix_t(0, 0) = 11;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ ix_t(0, 0) = -1;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ ix_t(0, 0) = 0;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+}
+
+TEST(SparseTensorTest, SparseTensorToDenseTensor) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+ auto vals_t = vals.vec<string>();
+
+ ix.matrix<int64>() = ix_t;
+
+ vals_t(0) = "hi0";
+ vals_t(1) = "hi1";
+ vals_t(2) = "hi2";
+ vals_t(3) = "hi3";
+ vals_t(4) = "hi4";
+
+ TensorShape shape({4, 4, 5});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ Tensor dense(DT_STRING, TensorShape({4, 4, 5}));
+ st.ToDense<string>(&dense);
+
+ auto dense_t = dense.tensor<string, 3>();
+ Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
+ for (int n = 0; n < N; ++n) {
+ for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
+ EXPECT_EQ(dense_t(ix_n), vals_t(n));
+ }
+
+ // Spot checks on the others
+ EXPECT_EQ(dense_t(0, 0, 1), "");
+ EXPECT_EQ(dense_t(0, 0, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 4), "");
+}
+
+TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+ auto vals_t = vals.vec<string>();
+
+ ix.matrix<int64>() = ix_t;
+
+ vals_t(0) = "hi0";
+ vals_t(1) = "hi1";
+ vals_t(2) = "hi2";
+ vals_t(3) = "hi3";
+ vals_t(4) = "hi4";
+
+ TensorShape shape({4, 4, 5});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ Tensor dense(DT_STRING, TensorShape({10, 10, 10}));
+ st.ToDense<string>(&dense);
+
+ auto dense_t = dense.tensor<string, 3>();
+ Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
+ for (int n = 0; n < N; ++n) {
+ for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
+ EXPECT_EQ(dense_t(ix_n), vals_t(n));
+ }
+
+ // Spot checks on the others
+ EXPECT_EQ(dense_t(0, 0, 1), "");
+ EXPECT_EQ(dense_t(0, 0, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 4), "");
+ EXPECT_EQ(dense_t(9, 0, 0), "");
+ EXPECT_EQ(dense_t(9, 0, 9), "");
+ EXPECT_EQ(dense_t(9, 9, 9), "");
+}
+
+TEST(SparseTensorTest, SparseTensorGroup) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_INT32, TensorShape({N}));
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<int32>();
+
+ ix_t = GetSimpleIndexTensor(N, NDIM);
+
+ vals_t(0) = 1; // associated with ix (000)
+ vals_t(1) = 2; // associated with ix (300)
+ vals_t(2) = 3; // associated with ix (200)
+ vals_t(3) = 4; // associated with ix (010)
+ vals_t(4) = 5; // associated with ix (002)
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ st.Reorder<int32>(order);
+
+ std::vector<std::vector<int64> > groups;
+ std::vector<TTypes<int64>::UnalignedConstMatrix> grouped_indices;
+ std::vector<TTypes<int32>::UnalignedVec> grouped_values;
+
+ // Group by index 0
+ auto gi = st.group({0});
+
+ // All the hard work is right here!
+ for (const auto& g : gi) {
+ groups.push_back(g.group());
+ VLOG(1) << "Group: " << str_util::Join(g.group(), ",");
+ VLOG(1) << "Indices: " << g.indices();
+ VLOG(1) << "Values: " << g.values<int32>();
+
+ grouped_indices.push_back(g.indices());
+ grouped_values.push_back(g.values<int32>());
+ }
+
+ // Group by dimension 0, we have groups: 0--, 2--, 3--
+ EXPECT_EQ(groups.size(), 3);
+ EXPECT_EQ(groups[0], std::vector<int64>({0}));
+ EXPECT_EQ(groups[1], std::vector<int64>({2}));
+ EXPECT_EQ(groups[2], std::vector<int64>({3}));
+
+ std::vector<Eigen::Tensor<int64, 2, Eigen::RowMajor> > expected_indices;
+ std::vector<Eigen::Tensor<int32, 1, Eigen::RowMajor> > expected_vals;
+
+ // First group: 000, 002, 010
+ expected_indices.emplace_back(3, NDIM); // 3 x 3 tensor
+ expected_vals.emplace_back(3); // vector of 3 values
+ expected_indices[0].setZero();
+ expected_indices[0](1, 2) = 2; // 002
+ expected_indices[0](2, 1) = 1; // 010
+ expected_vals[0].setConstant(-1);
+ expected_vals[0](0) = 1; // val associated with ix 000
+ expected_vals[0](1) = 5; // val associated with ix 002
+ expected_vals[0](2) = 4; // val associated with ix 010
+
+ // Second group: 200
+ expected_indices.emplace_back(1, NDIM);
+ expected_vals.emplace_back(1);
+ expected_indices[1].setZero();
+ expected_indices[1](0, 0) = 2; // 200
+ expected_vals[1](0) = 3; // val associated with ix 200
+
+ // Third group: 300
+ expected_indices.emplace_back(1, NDIM);
+ expected_vals.emplace_back(1);
+ expected_indices[2].setZero();
+ expected_indices[2](0, 0) = 3; // 300
+ expected_vals[2](0) = 2; // val associated with ix 300
+
+ for (std::size_t gix = 0; gix < groups.size(); ++gix) {
+ // Compare indices
+ auto gi_t = grouped_indices[gix];
+ Eigen::Tensor<bool, 0, Eigen::RowMajor> eval =
+ (gi_t == expected_indices[gix]).all();
+ EXPECT_TRUE(eval()) << gix << " indices: " << gi_t << " vs. "
+ << expected_indices[gix];
+
+ // Compare values
+ auto gv_t = grouped_values[gix];
+ eval = (gv_t == expected_vals[gix]).all();
+ EXPECT_TRUE(eval()) << gix << " values: " << gv_t << " vs. "
+ << expected_vals[gix];
+ }
+}
+
+TEST(SparseTensorTest, Concat) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_c = GetSimpleIndexTensor(N, NDIM);
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<string>();
+
+ ix_t = ix_c;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid());
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+
+ SparseTensor concatted = SparseTensor::Concat<string>({st, st, st, st});
+ EXPECT_EQ(concatted.order(), st.order());
+ TensorShape expected_shape({40, 10, 10});
+ EXPECT_EQ(concatted.shape(), expected_shape);
+ EXPECT_EQ(concatted.num_entries(), 4 * N);
+ EXPECT_TRUE(concatted.IndicesValid());
+
+ auto conc_ix_t = concatted.indices().matrix<int64>();
+ auto conc_vals_t = concatted.values().vec<string>();
+
+ for (int n = 0; n < 4; ++n) {
+ for (int i = 0; i < N; ++i) {
+ // Dimensions match except the primary dim, which is offset by
+ // shape[order[0]]
+ EXPECT_EQ(conc_ix_t(n * N + i, 0), 10 * n + ix_t(i, 0));
+ EXPECT_EQ(conc_ix_t(n * N + i, 1), ix_t(i, 1));
+ EXPECT_EQ(conc_ix_t(n * N + i, 2), ix_t(i, 2));
+
+ // Values match
+ EXPECT_EQ(conc_vals_t(n * N + i), vals_t(i));
+ }
+ }
+
+ // Concat works if non-primary ix is out of order, but output order
+ // is not defined
+ SparseTensor st_ooo(ix, vals, shape, {0, 2, 1}); // non-primary ix OOO
+ SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo});
+ std::vector<int64> expected_ooo{-1, -1, -1};
+ EXPECT_EQ(conc_ooo.order(), expected_ooo);
+ EXPECT_EQ(conc_ooo.shape(), expected_shape);
+ EXPECT_EQ(conc_ooo.num_entries(), 4 * N);
+}
+
+// TODO(ebrevdo): ReduceToDense(R={dim1,dim2,...}, reduce_fn, &output)
+// reduce_fn sees slices of resorted values based on generator (dim: DDIMS), and
+// slices of resorted indices on generator.
+
+} // namespace
+} // namespace sparse
+} // namespace tensorflow
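
For reference, the grouping pattern the SparseTensorGroup test exercises, distilled into a minimal sketch (not part of the patch; it assumes ix, vals, shape, and order are built as in that test, using the sparse::SparseTensor API added by this commit):

    // Group a reordered SparseTensor by its leading dimension.
    sparse::SparseTensor st(ix, vals, shape, order);
    st.Reorder<int32>(order);  // group() requires indices sorted by "order"
    for (const auto& g : st.group({0})) {
      int64 key = g.group()[0];             // the distinct dim-0 value
      auto group_ix = g.indices();          // indices of this group's entries
      auto group_vals = g.values<int32>();  // values of this group's entries
    }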
diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc
new file mode 100644
index 0000000000..00bc16f105
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader.cc
@@ -0,0 +1,230 @@
+#include "tensorflow/core/util/tensor_slice_reader.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/io/iterator.h"
+#include "tensorflow/core/lib/io/match.h"
+#include "tensorflow/core/lib/io/table.h"
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+TensorSliceReader::Table::~Table() {}
+
+namespace {
+class TensorSliceReaderTable : public TensorSliceReader::Table {
+ public:
+ explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t)
+ : file_(f), table_(t) {}
+
+ ~TensorSliceReaderTable() override {
+ delete table_;
+ delete file_;
+ }
+
+ bool Get(const string& key, string* value) override {
+ std::unique_ptr<table::Iterator> iter(table_->NewIterator());
+ iter->Seek(key);
+ if (iter->Valid() && iter->key() == key) {
+ StringPiece v = iter->value();
+ value->assign(v.data(), v.size());
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ private:
+ RandomAccessFile* file_;
+ table::Table* table_;
+};
+} // namespace
+
+Status OpenTableTensorSliceReader(const string& fname,
+ TensorSliceReader::Table** result) {
+ *result = nullptr;
+ Env* env = Env::Default();
+ RandomAccessFile* f = nullptr;
+ Status s = env->NewRandomAccessFile(fname, &f);
+ if (s.ok()) {
+ uint64 file_size;
+ s = env->GetFileSize(fname, &file_size);
+ if (s.ok()) {
+ table::Options options;
+ table::Table* table;
+ s = table::Table::Open(options, f, file_size, &table);
+ if (s.ok()) {
+ *result = new TensorSliceReaderTable(f, table);
+ return Status::OK();
+ } else {
+ s = Status(s.code(),
+ strings::StrCat(s.error_message(),
+ ": perhaps your file is in a different "
+ "file format and you need to use a "
+ "different restore operator?"));
+ }
+ }
+ }
+ LOG(WARNING) << "Could not open " << fname << ": " << s;
+ delete f;
+ return s;
+}
+
+TensorSliceReader::TensorSliceReader(const string& filepattern,
+ OpenTableFunction open_function)
+ : TensorSliceReader(filepattern, open_function, kLoadAllShards) {}
+
+TensorSliceReader::TensorSliceReader(const string& filepattern,
+ OpenTableFunction open_function,
+ int preferred_shard)
+ : filepattern_(filepattern), open_function_(open_function) {
+ VLOG(1) << "TensorSliceReader for " << filepattern;
+ Status s = io::GetMatchingFiles(Env::Default(), filepattern, &fnames_);
+ if (!s.ok()) {
+ status_ = errors::InvalidArgument(
+ "Unsuccessful TensorSliceReader constructor: "
+ "Failed to get matching files on ",
+ filepattern, ": ", s.ToString());
+ return;
+ }
+ if (fnames_.empty()) {
+ status_ = errors::NotFound(
+ "Unsuccessful TensorSliceReader constructor: "
+ "Failed to find any matching files for ",
+ filepattern);
+ return;
+ }
+ sss_.resize(fnames_.size());
+ for (size_t shard = 0; shard < fnames_.size(); ++shard) {
+ fname_to_index_.insert(std::make_pair(fnames_[shard], shard));
+ }
+ if (preferred_shard == kLoadAllShards || fnames_.size() == 1 ||
+ static_cast<size_t>(preferred_shard) >= fnames_.size()) {
+ LoadAllShards();
+ } else {
+ VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_;
+ LoadShard(preferred_shard);
+ }
+}
+
+void TensorSliceReader::LoadShard(int shard) const {
+ CHECK_LT(shard, sss_.size());
+ if (sss_[shard] || !status_.ok()) {
+ return; // Already loaded, or invalid.
+ }
+ string value;
+ SavedTensorSlices sts;
+ const string fname = fnames_[shard];
+ VLOG(1) << "Reading meta data from file " << fname << "...";
+ Table* table;
+ Status s = open_function_(fname, &table);
+ if (!s.ok()) {
+ status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
+ s.ToString());
+ return;
+ }
+ sss_[shard].reset(table);
+ if (!(table->Get(kSavedTensorSlicesKey, &value) &&
+ ParseProtoUnlimited(&sts, value))) {
+ status_ = errors::Internal(
+ "Failed to find the saved tensor slices at the beginning of the "
+ "checkpoint file: ",
+ fname);
+ return;
+ }
+ for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
+ TensorShape ssm_shape(ssm.shape());
+ for (const TensorSliceProto& tsp : ssm.slice()) {
+ TensorSlice ss_slice(tsp);
+ RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname, ss_slice);
+ }
+ }
+}
+
+void TensorSliceReader::LoadAllShards() const {
+ VLOG(1) << "Loading all shards for " << filepattern_;
+ for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) {
+ LoadShard(i);
+ }
+ all_shards_loaded_ = true;
+}
+
+const TensorSliceSet* TensorSliceReader::FindTensorSlice(
+ const string& name, const TensorSlice& slice,
+ std::vector<std::pair<TensorSlice, string>>* details) const {
+ const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
+ if (tss && !tss->QueryMeta(slice, details)) {
+ return nullptr;
+ }
+ return tss;
+}
+
+TensorSliceReader::~TensorSliceReader() { gtl::STLDeleteValues(&tensors_); }
+
+void TensorSliceReader::RegisterTensorSlice(const string& name,
+ const TensorShape& shape,
+ DataType type, const string& tag,
+ const TensorSlice& slice) const {
+ TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
+ // Create a tensor slice set if needed
+ if (!tss) {
+ tss = new TensorSliceSet(shape, type);
+ tensors_.insert(std::make_pair(name, tss));
+ } else {
+ // Check if the shapes match
+ TensorShape tss_shape(tss->shape());
+ if (!shape.IsSameSize(tss_shape)) {
+ status_ =
+ errors::Internal("Incompatible tensor shapes detected for tensor ",
+ name, ": existing = ", tss_shape.DebugString(),
+ ", new = ", shape.DebugString());
+ return;
+ }
+ if (type != tss->type()) {
+ status_ =
+ errors::Internal("Incompatible tensor types detected for tensor ",
+ name, ": existing = ", DataTypeString(tss->type()),
+ ", new = ", DataTypeString(type));
+ return;
+ }
+ }
+ // Register the tensor slices without the actual data.
+ Status s = tss->Register(slice, tag, nullptr);
+ if (!s.ok()) {
+ status_ = s;
+ }
+}
+
+bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape,
+ DataType* type) const {
+ mutex_lock l(mu_);
+ const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
+ if (!tss && !all_shards_loaded_) {
+ VLOG(1) << "Did not find tensor in preferred shard, loading all shards: "
+ << name;
+ LoadAllShards();
+ tss = gtl::FindPtrOrNull(tensors_, name);
+ }
+ if (tss) {
+ if (shape) {
+ *shape = tss->shape();
+ }
+ if (type) {
+ *type = tss->type();
+ }
+ return true;
+ } else {
+ return false;
+ }
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h
new file mode 100644
index 0000000000..b5f26a689b
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader.h
@@ -0,0 +1,157 @@
+// The utility to read checkpoints for google brain tensor ops and v3
+// checkpoints for dist_belief.
+//
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/util/saved_tensor_slice.pb.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_set.h"
+#include "tensorflow/core/util/tensor_slice_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+// The reader first reads in all the metadata about all the tensor slices,
+// then reads the relevant data on demand to produce the data for the
+// slices needed.
+// NOTE(yangke): another way to do this is to first load a list of the tensor
+// slices needed and then just selectively read some of the meta data. That
+// might optimize the loading but makes the logic a bit more complicated. We
+// might want to revisit that.
+// TODO(yangke): consider moving to TensorProto.
+class TensorSliceReader {
+ public:
+ // Abstract interface for reading data out of a tensor slice checkpoint file
+ class Table {
+ public:
+ virtual ~Table();
+ virtual bool Get(const string& key, string* value) = 0;
+ };
+ typedef std::function<Status(const string&, Table**)> OpenTableFunction;
+
+ static const int kLoadAllShards = -1;
+ TensorSliceReader(const string& filepattern, OpenTableFunction open_function);
+ TensorSliceReader(const string& filepattern, OpenTableFunction open_function,
+ int preferred_shard);
+ virtual ~TensorSliceReader();
+
+ // Get the filename this reader is attached to.
+ const string& filepattern() const { return filepattern_; }
+
+ // Get the number of files matched.
+ int num_files() const { return sss_.size(); }
+
+ // Get the status of the reader.
+ const Status status() const { return status_; }
+
+ // Checks if the reader contains any slice of a tensor. In case the reader
+ // does contain the tensor, if "shape" is not nullptr, fill "shape" with the
+ // shape of the tensor; if "type" is not nullptr, fill "type" with the type
+ // of the tensor.
+ bool HasTensor(const string& name, TensorShape* shape, DataType* type) const;
+
+ // Checks if the reader contains all the data about a tensor slice, and if
+ // yes, copies the data of the slice to "data". The caller needs to make sure
+ // that "data" points to a buffer that holds enough data.
+ // This is a slow function since it needs to read sstables.
+ template <typename T>
+ bool CopySliceData(const string& name, const TensorSlice& slice,
+ T* data) const;
+
+ // Get the tensors.
+ const std::unordered_map<string, TensorSliceSet*>& Tensors() const {
+ return tensors_;
+ }
+
+ private:
+ friend class TensorSliceWriteTestHelper;
+
+ void LoadShard(int shard) const;
+ void LoadAllShards() const;
+ void RegisterTensorSlice(const string& name, const TensorShape& shape,
+ DataType type, const string& tag,
+ const TensorSlice& slice) const;
+
+ const TensorSliceSet* FindTensorSlice(
+ const string& name, const TensorSlice& slice,
+ std::vector<std::pair<TensorSlice, string>>* details) const;
+
+ const string filepattern_;
+ const OpenTableFunction open_function_;
+ std::vector<string> fnames_;
+ std::unordered_map<string, int> fname_to_index_;
+
+ // Guards the attributes below.
+ mutable mutex mu_;
+ mutable bool all_shards_loaded_ = false;
+ mutable std::vector<std::unique_ptr<Table>> sss_;
+ mutable std::unordered_map<string, TensorSliceSet*> tensors_;
+ mutable Status status_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceReader);
+};
+
+Status OpenTableTensorSliceReader(const string& fname,
+ TensorSliceReader::Table** table);
+
+template <typename T>
+bool TensorSliceReader::CopySliceData(const string& name,
+ const TensorSlice& slice, T* data) const {
+ std::vector<std::pair<TensorSlice, string>> details;
+ const TensorSliceSet* tss;
+ {
+ mutex_lock l(mu_);
+ tss = FindTensorSlice(name, slice, &details);
+ if (!tss && !all_shards_loaded_) {
+ VLOG(1) << "Did not find slice in preferred shard, loading all shards."
+ << name << ": " << slice.DebugString();
+ LoadAllShards();
+ tss = FindTensorSlice(name, slice, &details);
+ }
+ if (!tss) {
+ // No such tensor
+ return false;
+ }
+ }
+ // We have the data -- copy it over.
+ string value;
+ for (const auto& x : details) {
+ const TensorSlice& slice_s = x.first;
+ const string& fname = x.second;
+ int idx = gtl::FindWithDefault(fname_to_index_, fname, -1);
+ CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname;
+ // We read a record in the corresponding sstable
+ const string key = EncodeTensorNameSlice(name, slice_s);
+ CHECK(sss_[idx]->Get(key, &value))
+ << "Failed to seek to the record for tensor " << name << ", slice "
+ << slice_s.DebugString() << ": computed key = " << key;
+ SavedTensorSlices sts;
+ CHECK(ParseProtoUnlimited(&sts, value))
+ << "Failed to parse the record for tensor " << name << ", slice "
+ << slice_s.DebugString() << ": computed key = " << key;
+ CopyDataFromTensorSliceToTensorSlice(
+ tss->shape(), slice_s, slice,
+ checkpoint::TensorProtoData<T>(sts.data().data()), data);
+ }
+ return true;
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
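
For reference, a minimal sketch of the read path this header exposes (not part of the patch; the checkpoint pattern "/tmp/ckpt_*" and the tensor name "test" are hypothetical, mirroring the tests later in this diff):

    // Open all shards matching the pattern and copy out one slice.
    TensorSliceReader reader("/tmp/ckpt_*", OpenTableTensorSliceReader);
    TF_CHECK_OK(reader.status());
    TensorShape shape;
    DataType type;
    if (reader.HasTensor("test", &shape, &type)) {
      // Rows 0-1 of a 4 x 5 tensor; the caller sizes the buffer.
      TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
      float buf[10];
      CHECK(reader.CopySliceData("test", slice, buf));
    }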
diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc
new file mode 100644
index 0000000000..af81d0115e
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader_cache.cc
@@ -0,0 +1,94 @@
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+TensorSliceReaderCacheWrapper::TensorSliceReaderCacheWrapper() {}
+TensorSliceReaderCacheWrapper::~TensorSliceReaderCacheWrapper() {
+ if (cache_) {
+ delete cache_;
+ }
+ cache_ = nullptr;
+}
+
+const TensorSliceReader* TensorSliceReaderCacheWrapper::GetReader(
+ const string& filepattern,
+ TensorSliceReader::OpenTableFunction open_function,
+ int preferred_shard) const {
+ mutex_lock l(mu_);
+ if (!cache_) {
+ cache_ = new TensorSliceReaderCache;
+ }
+ return cache_->GetReader(filepattern, open_function, preferred_shard);
+}
+
+TensorSliceReaderCache::TensorSliceReaderCache() {}
+
+TensorSliceReaderCache::~TensorSliceReaderCache() {
+ for (const auto& pair : readers_) {
+ delete pair.second.second;
+ }
+}
+
+const TensorSliceReader* TensorSliceReaderCache::GetReader(
+ const string& filepattern,
+ TensorSliceReader::OpenTableFunction open_function, int preferred_shard) {
+ mutex_lock l(mu_);
+
+ // Get the function pointer from the open_function value.
+ TensorSliceReaderCache::OpenFuncType* func_ptr =
+ open_function.target<TensorSliceReaderCache::OpenFuncType>();
+ if (!func_ptr) {
+ // We could not get the pointer, no caching is possible.
+ LOG(WARNING) << "Caching disabled because the open function is a lambda.";
+ return nullptr;
+ }
+
+ // Wait if another thread is already trying to open the same files.
+ while (still_opening_.find(filepattern) != still_opening_.end()) {
+ cv_.wait(l);
+ }
+
+ TensorSliceReader* reader = nullptr;
+ if (readers_.find(filepattern) == readers_.end()) {
+ VLOG(1) << "Creating new TensorSliceReader for " << filepattern;
+ still_opening_.insert(filepattern);
+ // Release the lock temporarily, as constructing a TensorSliceReader is
+ // expensive.
+ mu_.unlock();
+ TensorSliceReader* tmp_reader(
+ new TensorSliceReader(filepattern, open_function, preferred_shard));
+ // Acquire the lock again.
+ mu_.lock();
+ if (tmp_reader->status().ok()) {
+ reader = tmp_reader;
+ readers_[filepattern] = make_pair(*func_ptr, reader);
+ } else {
+ delete tmp_reader;
+ }
+ CHECK_EQ(1, still_opening_.erase(filepattern));
+ VLOG(1) << "Cached TensorSliceReader for " << filepattern << ": " << reader;
+ } else {
+ auto cached_val = readers_[filepattern];
+ if (cached_val.first == *func_ptr) {
+ reader = cached_val.second;
+ VLOG(1) << "Using cached TensorSliceReader for " << filepattern << ": "
+ << reader;
+ } else {
+ LOG(WARNING) << "Caching disabled because the checkpoint file "
+ << "is being opened with two different open functions: "
+ << filepattern;
+ }
+ }
+
+ cv_.notify_all();
+ return reader;
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_reader_cache.h b/tensorflow/core/util/tensor_slice_reader_cache.h
new file mode 100644
index 0000000000..eaeeeec83f
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader_cache.h
@@ -0,0 +1,73 @@
+// The utility to read checkpoints for google brain tensor ops and v3
+// checkpoints for dist_belief.
+//
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceReaderCache;
+
+// Wrapper to a lazily allocated TensorSliceReaderCache.
+class TensorSliceReaderCacheWrapper {
+ public:
+ TensorSliceReaderCacheWrapper();
+ ~TensorSliceReaderCacheWrapper();
+
+ // Same as TensorSliceReaderCache::GetReader().
+ const TensorSliceReader* GetReader(
+ const string& filepattern,
+ TensorSliceReader::OpenTableFunction open_function,
+ int preferred_shard) const;
+
+ private:
+ mutable mutex mu_;
+ mutable TensorSliceReaderCache* cache_ = nullptr;
+};
+
+// A cache of TensorSliceReaders.
+class TensorSliceReaderCache {
+ public:
+ TensorSliceReaderCache();
+ ~TensorSliceReaderCache();
+
+ // Returns the TensorSliceReader corresponding to 'filepattern' and the
+ // open_function. May return nullptr if we cannot create a new
+ // TensorSliceReader for the filepattern/open_function combination.
+ const TensorSliceReader* GetReader(
+ const string& filepattern,
+ TensorSliceReader::OpenTableFunction open_function, int preferred_shard);
+
+ private:
+ // Need to use a regular function type in the key map as std::function does
+ // not support ==.
+ typedef Status (*OpenFuncType)(const string&, TensorSliceReader::Table**);
+
+ // Protects attributes below.
+ mutex mu_;
+
+ // Maps of opened readers.
+ std::unordered_map<string, std::pair<OpenFuncType, TensorSliceReader*>>
+ readers_;
+
+ // Set of keys that a previous GetReader() call is still trying to populate.
+ std::set<string> still_opening_;
+
+ // Condition variable to notify when a reader has been created.
+ condition_variable cv_;
+};
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
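
A usage sketch for the cache (not part of the patch; the file pattern is hypothetical). As the .cc above explains, caching only engages when open_function has a plain function-pointer target such as OpenTableTensorSliceReader, and GetReader() may return nullptr on failure:

    TensorSliceReaderCache cache;
    const TensorSliceReader* r1 = cache.GetReader(
        "/tmp/ckpt_*", OpenTableTensorSliceReader,
        TensorSliceReader::kLoadAllShards);
    const TensorSliceReader* r2 = cache.GetReader(
        "/tmp/ckpt_*", OpenTableTensorSliceReader,
        TensorSliceReader::kLoadAllShards);
    CHECK_EQ(r1, r2);  // same cached reader; the cache retains ownership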
diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc
new file mode 100644
index 0000000000..e14b920003
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader_test.cc
@@ -0,0 +1,395 @@
+#include "tensorflow/core/util/tensor_slice_reader.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+// A simple test where we write a few tensor slices with a number of tensor
+// slice writers and then read them back from a tensor slice reader.
+//
+// We have a 2-d tensor of shape 4 X 5 that looks like this:
+//
+// 0 1 2 3 4
+// 5 6 7 8 9
+// 10 11 12 13 14
+// 15 16 17 18 19
+//
+// We assume this is a row-major matrix.
+
+void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function,
+ TensorSliceReader::OpenTableFunction open_function) {
+ const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint");
+
+ TensorShape shape({4, 5});
+
+ // File #0 contains a slice that is the top two rows:
+ //
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ const string fname = strings::StrCat(fname_base, "_0");
+ TensorSliceWriter writer(fname, create_function);
+ const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // File #1 contains two slices:
+ //
+ // slice #0 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ //
+ // slice #1 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ {
+ const string fname = strings::StrCat(fname_base, "_1");
+ TensorSliceWriter writer(fname, create_function);
+ // slice #0
+ {
+ const float data[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ // slice #1
+ {
+ const float data[] = {18, 19};
+ TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we need to read the tensor slices
+ const string filepattern = strings::StrCat(fname_base, "_*");
+ TensorSliceReader reader(filepattern, open_function);
+ EXPECT_OK(reader.status());
+ EXPECT_EQ(2, reader.num_files());
+
+ // We query some of the tensors
+ {
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("test", &shape, &type));
+ EXPECT_EQ(
+ "dim { size: 4 } "
+ "dim { size: 5 }",
+ shape.DebugString());
+ EXPECT_EQ(DT_FLOAT, type);
+ EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr));
+ }
+
+ // Now we query some slices
+ //
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ float results[10];
+ EXPECT_TRUE(reader.CopySliceData("test", s, results));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ float expected[] = {5, 6, 7, 8, 9};
+ float results[5];
+ EXPECT_TRUE(reader.CopySliceData("test", s, results));
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ float results[6];
+ EXPECT_FALSE(reader.CopySliceData("test", s, results));
+ }
+}
+
+TEST(TensorSliceReaderTest, SimpleFloat) {
+ SimpleFloatHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader);
+}
+
+template <typename T, typename U>
+void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function,
+ TensorSliceReader::OpenTableFunction open_function,
+ const string& checkpoint_file) {
+ const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file);
+
+ TensorShape shape({4, 5});
+
+ // File #0 contains a slice that is the top two rows:
+ //
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ const string fname = strings::StrCat(fname_base, "_0");
+ TensorSliceWriter writer(fname, create_function);
+ const T data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // File #1 contains two slices:
+ //
+ // slice #0 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ //
+ // slice #1 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ {
+ const string fname = strings::StrCat(fname_base, "_1");
+ TensorSliceWriter writer(fname, create_function);
+ // slice #0
+ {
+ const T data[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ // slice #1
+ {
+ const T data[] = {18, 19};
+ TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we need to read the tensor slices
+ const string filepattern = strings::StrCat(fname_base, "_*");
+ TensorSliceReader reader(filepattern, open_function);
+ EXPECT_OK(reader.status());
+ EXPECT_EQ(2, reader.num_files());
+
+ // We query some of the tensors
+ {
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("test", &shape, &type));
+ EXPECT_EQ(
+ "dim { size: 4 } "
+ "dim { size: 5 }",
+ shape.DebugString());
+ EXPECT_EQ(DataTypeToEnum<T>::v(), type);
+ EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr));
+ }
+
+ // Now we query some slices
+ //
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ T expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ U results[10];
+ EXPECT_TRUE(reader.CopySliceData("test", s, results));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ T expected[] = {5, 6, 7, 8, 9};
+ U results[5];
+ EXPECT_TRUE(reader.CopySliceData("test", s, results));
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ U results[6];
+ EXPECT_FALSE(reader.CopySliceData("test", s, results));
+ }
+}
+
+#define TEST_SIMPLE_INT(TYPE, SAVED_TYPE) \
+ TEST(TensorSliceReaderTest, Simple##TYPE) { \
+ SimpleIntXHelper<TYPE, SAVED_TYPE>(CreateTableTensorSliceBuilder, \
+ OpenTableTensorSliceReader, \
+ #TYPE "_checkpoint"); \
+ }
+
+TEST_SIMPLE_INT(int32, int32)
+TEST_SIMPLE_INT(int64, int64)
+TEST_SIMPLE_INT(int16, int32)
+TEST_SIMPLE_INT(int8, int32)
+TEST_SIMPLE_INT(uint8, int32)
+
+void CachedTensorSliceReaderTesterHelper(
+ TensorSliceWriter::CreateBuilderFunction create_function,
+ TensorSliceReader::OpenTableFunction open_function) {
+ const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint");
+
+ TensorShape shape({4, 5});
+
+ // File #0 contains a slice that is the top two rows:
+ //
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ const string fname = strings::StrCat(fname_base, "_0");
+ TensorSliceWriter writer(fname, create_function);
+ const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // File #1 contains two slices:
+ //
+ // slice #0 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ //
+ // slice #1 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ {
+ const string fname = strings::StrCat(fname_base, "_1");
+ TensorSliceWriter writer(fname, create_function);
+ // slice #0
+ {
+ const float data[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ // slice #1
+ {
+ const float data[] = {18, 19};
+ TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we need to read the tensor slices
+ TensorSliceReaderCache cache;
+ const string filepattern = strings::StrCat(fname_base, "_*");
+ const TensorSliceReader* reader = cache.GetReader(
+ filepattern, open_function, TensorSliceReader::kLoadAllShards);
+ EXPECT_TRUE(reader != nullptr);
+ EXPECT_EQ(2, reader->num_files());
+
+ // We query some of the tensors
+ {
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader->HasTensor("test", &shape, &type));
+ EXPECT_EQ(
+ "dim { size: 4 } "
+ "dim { size: 5 }",
+ shape.DebugString());
+ EXPECT_EQ(DT_FLOAT, type);
+ EXPECT_FALSE(reader->HasTensor("don't exist", nullptr, nullptr));
+ }
+
+ // Make sure the reader is cached.
+ const TensorSliceReader* reader2 = cache.GetReader(
+ filepattern, open_function, TensorSliceReader::kLoadAllShards);
+ EXPECT_EQ(reader, reader2);
+
+ reader = cache.GetReader("file_does_not_exist", open_function,
+ TensorSliceReader::kLoadAllShards);
+ EXPECT_TRUE(reader == nullptr);
+}
+
+TEST(CachedTensorSliceReaderTest, SimpleFloat) {
+ CachedTensorSliceReaderTesterHelper(CreateTableTensorSliceBuilder,
+ OpenTableTensorSliceReader);
+}
+
+} // namespace
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc
new file mode 100644
index 0000000000..765686f189
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_set.cc
@@ -0,0 +1,148 @@
+#include "tensorflow/core/util/tensor_slice_set.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/tensor_slice_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type)
+ : shape_(shape), type_(type) {}
+
+TensorSliceSet::~TensorSliceSet() {}
+
+Status TensorSliceSet::Register(const TensorSlice& slice,
+ const string& tag, const float* data) {
+ TensorShape result_shape;
+ TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
+ string str = slice.DebugString();
+ // We check if there is any intersection between this slice and any of the
+ // registered slices.
+ for (const auto& x : slices_) {
+ if (slice.Overlaps(x.second.slice)) {
+ return errors::Internal("Overlapping slices: existing slice = ", x.first,
+ ", new slice = ", str);
+ }
+ }
+ // No overlap: we can now insert the slice
+ TensorSliceSet::SliceInfo info = {slice, tag, data,
+ result_shape.num_elements()};
+ slices_.insert(std::make_pair(str, info));
+ return Status::OK();
+}
+
+// TODO(yangke): merge Query() with QueryMeta()
+bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
+ Status s;
+ string str = slice.DebugString();
+ // First we check if there is an exact match (this is the dominant case).
+ const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
+ if (info) {
+ if (data) {
+ std::copy_n(info->data, info->num_floats, data);
+ }
+ return true;
+ } else {
+ // We didn't find any exact match but there is still a possibility that
+ // multiple existing slices can be patched together to output the slice.
+ // We figure this out by computing the intersection of each of the existing
+ // slices with the query slice, and check if the union of all these
+ // intersections cover the entire slice. We rely on the fact that the
+ // existing slices don't have any intersection among themselves.
+ TensorShape target_shape;
+ Status s;
+ s = slice.SliceTensorShape(shape_, &target_shape);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ int64 total_size = target_shape.num_elements();
+
+ int64 overlap_size = 0;
+ TensorSlice intersection;
+ TensorShape inter_shape;
+ for (const auto& x : slices_) {
+ if (slice.Intersect(x.second.slice, &intersection)) {
+ s = intersection.SliceTensorShape(shape_, &inter_shape);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ overlap_size += inter_shape.num_elements();
+ }
+ }
+ if (total_size == overlap_size) {
+ // We have it!
+ // Now we need to copy the data to "data"
+ if (data) {
+ for (const auto& x : slices_) {
+ CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice,
+ x.second.data, data);
+ }
+ }
+ return true;
+ } else {
+ // We don't have all the data for the requested tensor slice
+ return false;
+ }
+ }
+}
+
+bool TensorSliceSet::QueryMeta(
+ const TensorSlice& slice,
+ std::vector<std::pair<TensorSlice, string>>* results) const {
+ results->clear();
+ Status s;
+ string str = slice.DebugString();
+ // First we check if there is an exact match (this is the dominant case).
+ const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
+ if (info) {
+ results->emplace_back(info->slice, info->tag);
+ return true;
+ } else {
+ // We didn't find any exact match but there is still a possibility that
+ // multiple existing slices can be patched together to output the slice.
+ // We figure this out by computing the intersection of each of the existing
+ // slices with the query slice, and check if the union of all these
+ // intersections cover the entire slice. We rely on the fact that the
+ // existing slices don't have any intersection among themselves.
+ TensorShape target_shape;
+ Status s;
+ s = slice.SliceTensorShape(shape_, &target_shape);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ int64 total_size = target_shape.num_elements();
+
+ int64 overlap_size = 0;
+ TensorSlice intersection;
+ TensorShape inter_shape;
+ for (const auto& x : slices_) {
+ if (slice.Intersect(x.second.slice, &intersection)) {
+ s = intersection.SliceTensorShape(shape_, &inter_shape);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ overlap_size += inter_shape.num_elements();
+ results->emplace_back(x.second.slice, x.second.tag);
+ }
+ }
+ if (total_size == overlap_size) {
+ // We have it!
+ return true;
+ } else {
+ // We don't have all the data for the requested tensor slice
+ results->clear();
+ return false;
+ }
+ }
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_set.h b/tensorflow/core/util/tensor_slice_set.h
new file mode 100644
index 0000000000..f3f7ac0e76
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_set.h
@@ -0,0 +1,73 @@
+// A class to manage slices of a tensor. You can "register" a set of slices
+// for a tensor and then "query" whether we have data for a given slice.
+
+// TODO(yangke): consider moving it to a more private place so that we don't
+// need to expose the API.
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
+
+#include <string> // for string
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h" // for int64
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/core/stringpiece.h" // for StringPiece
+#include "tensorflow/core/public/status.h" // for Status
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceSet {
+ public:
+ TensorSliceSet(const TensorShape& shape, DataType type);
+ virtual ~TensorSliceSet();
+
+ const TensorShape& shape() const { return shape_; }
+ const DataType type() const { return type_; }
+
+ // Register a new slice for the tensor. The "tag" is an arbitrary string
+ // associated with the slice (in one application it denotes the name of the
+ // file that contains the slice); the "data" points to the data of the tensor
+ // slice (it can be a nullptr).
+ // We do not take ownership of "data"; if it is not nullptr, the caller
+ // needs to make sure the data stays alive for the lifetime of the tensor
+ // slice set.
+ Status Register(const TensorSlice& slice, const string& tag,
+ const float* data);
+
+ // Query about a new slice: checks if we have data for "slice" and if we have
+ // the data and "data" is not nullptr, fill "data" with the slice data. The
+ // caller needs to make sure "data" point to a large eough buffer.
+ // TODO(yangke): avoid unnecessary copying by using a core::RefCounted
+ // pointer.
+ bool Query(const TensorSlice& slice, float* data) const;
+
+ // Alternative way of querying about a new slice: instead of copying the
+ // data, it returns a list of meta data about the stored slices that will
+ // supply data for the slice.
+ bool QueryMeta(
+ const TensorSlice& slice,
+ std::vector<std::pair<tensorflow::TensorSlice, string>>* results) const;
+
+ private:
+ const TensorShape shape_;
+ const DataType type_;
+ struct SliceInfo {
+ TensorSlice slice;
+ const string tag;
+ const float* data;
+ int64 num_floats;
+ };
+ // We maintain a mapping from the slice string to the slice information.
+ std::unordered_map<string, SliceInfo> slices_;
+};
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
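
A register/query sketch for TensorSliceSet (not part of the patch; it mirrors the tests that follow, with a hypothetical "tag" string):

    TensorSliceSet tss(TensorShape({4, 5}), DT_FLOAT);
    const float top[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};  // rows 0-1
    TF_CHECK_OK(tss.Register(TensorSlice::ParseOrDie("0,2:-"), "tag", top));
    float row1[5];
    CHECK(tss.Query(TensorSlice::ParseOrDie("1,1:-"), row1));      // covered
    CHECK(!tss.Query(TensorSlice::ParseOrDie("3,1:-"), nullptr));  // not covered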
diff --git a/tensorflow/core/util/tensor_slice_set_test.cc b/tensorflow/core/util/tensor_slice_set_test.cc
new file mode 100644
index 0000000000..fb2f46f34c
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_set_test.cc
@@ -0,0 +1,227 @@
+#include "tensorflow/core/util/tensor_slice_set.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+// A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this:
+//
+// 0 1 2 3 4
+// 5 6 7 8 9
+// 10 11 12 13 14
+// 15 16 17 18 19
+//
+// We assume this is a row-major matrix.
+//
+// We store the tensor in a couple of slices and verify that we can recover all
+// of them.
+TEST(TensorSliceSetTest, QueryTwoD) {
+ TensorShape shape({4, 5});
+
+ TensorSliceSet tss(shape, DT_FLOAT);
+ // We store a few slices.
+
+ // Slice #1 is the top two rows:
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(tss.Register(slice_1, "", src_1));
+
+ // Slice #2 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ const float src_2[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(tss.Register(slice_2, "", src_2));
+
+ // Slice #3 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ const float src_3[] = {18, 19};
+ TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(tss.Register(slice_3, "", src_3));
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we query some of the slices
+
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ float results[10];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ float expected[] = {5, 6, 7, 8, 9};
+ float results[5];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #3 is a more complicated match: it needs the combination of a couple
+ // of slices
+ // . . . . .
+ // 5 6 7 . .
+ // 10 11 12 . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
+ float expected[] = {5, 6, 7, 10, 11, 12};
+ float results[6];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ float results[6];
+ EXPECT_FALSE(tss.Query(s, results));
+ }
+}
+
+// Testing the meta version of the tensor slice set.
+TEST(TensorSliceSetTest, QueryMetaTwoD) {
+ TensorShape shape({4, 5});
+
+ TensorSliceSet tss(shape, DT_INT32);
+ // We store a few slices.
+
+ // Slice #1 is the top two rows:
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr));
+
+ // Slice #2 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr));
+
+ // Slice #3 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr));
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we query some of the slices
+
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ // We just need slice_1 for this
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(1, results.size());
+ EXPECT_EQ("0,2:-", results[0].first.DebugString());
+ EXPECT_EQ("slice_1", results[0].second);
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ // We just need slice_1 for this
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(1, results.size());
+ EXPECT_EQ("0,2:-", results[0].first.DebugString());
+ EXPECT_EQ("slice_1", results[0].second);
+ }
+
+ // Slice #3 is a more complicated match: it needs the combination of a couple
+ // of slices
+ // . . . . .
+ // 5 6 7 . .
+ // 10 11 12 . .
+ // . . . . .
+ // We need both slice_1 and slice_2 for this.
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(2, results.size());
+ EXPECT_EQ("2,2:0,3", results[0].first.DebugString());
+ EXPECT_EQ("slice_2", results[0].second);
+ EXPECT_EQ("0,2:-", results[1].first.DebugString());
+ EXPECT_EQ("slice_1", results[1].second);
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_FALSE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(0, results.size());
+ }
+}
+
+} // namespace
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_util.h b/tensorflow/core/util/tensor_slice_util.h
new file mode 100644
index 0000000000..5422c3bef3
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_util.h
@@ -0,0 +1,88 @@
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
+
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+// Some hackery to use Eigen tensors to copy over tensor slices of
+// variable-dimension tensors.
+// TODO(yangke): get rid of that once the variable dimension tensor support is
+// in.
+static const int kTensorSliceMaxRank = 8;
+
+// Create a tensor map with the given shape: we support up to 8 dimensions. If
+// the shape has fewer than 8 dimensions, we pad the remaining dimensions with 1.
+template <typename T>
+Eigen::TensorMap<Eigen::Tensor<T, kTensorSliceMaxRank, Eigen::RowMajor>>
+GetEigenTensorMapFromTensorShape(const TensorShape& shape, T* data) {
+ Eigen::DSizes<Eigen::DenseIndex, kTensorSliceMaxRank> dsizes =
+ shape.AsEigenDSizesWithPadding<kTensorSliceMaxRank>();
+ Eigen::TensorMap<Eigen::Tensor<T, kTensorSliceMaxRank, Eigen::RowMajor>> eig(
+ data, dsizes);
+ return eig;
+}
+
+// Given a tensor described by "shape", two slices "slice_s" and "slice_d",
+// and two pointers "ptr_s" and "ptr_d", where "ptr_s" points to a chunk of
+// memory that stores the data for "slice_s" and "ptr_d" points to a chunk of
+// memory that stores the data for "slice_d". This function copies the data
+// that belongs to the intersection of the two slices from slice_s to
+// slice_d. Uses Tensor cast<DstT>() to convert from SrcT to DstT. Returns true
+// iff the two slices share any intersection (and thus some data is copied).
+// TODO(yangke): figure out if we can make it private.
+template <typename SrcT, typename DstT>
+static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape,
+ const TensorSlice& slice_s,
+ const TensorSlice& slice_d,
+ const SrcT* ptr_s,
+ DstT* ptr_d) {
+ CHECK_LE(shape.dims(), kTensorSliceMaxRank) << "Only tensors of size up to "
+ << kTensorSliceMaxRank
+ << " are supported";
+ // We need to compute the intersection of the two slices.
+ TensorSlice inter;
+ if (!slice_s.Intersect(slice_d, &inter)) {
+ // There is no intersection: returns false.
+ return false;
+ } else {
+ // We need to compute the applied shapes after applying slice_s and
+ // slice_d.
+ TensorShape shp_s, shp_d;
+ Status s;
+ s = slice_s.SliceTensorShape(shape, &shp_s);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ s = slice_d.SliceTensorShape(shape, &shp_d);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+
+ // We need to compute the relative slice of "inter" w.r.t. both slice_s and
+ // slice_d.
+ TensorSlice rel_s, rel_d;
+ slice_s.ComputeRelative(inter, &rel_s);
+ slice_d.ComputeRelative(inter, &rel_d);
+
+ // Get the eigen tensor maps to the data.
+ auto t_s = GetEigenTensorMapFromTensorShape(shp_s, ptr_s);
+ auto t_d = GetEigenTensorMapFromTensorShape(shp_d, ptr_d);
+
+ Eigen::DSizes<Eigen::DenseIndex, kTensorSliceMaxRank> s_start, s_len,
+ d_start, d_len;
+
+ rel_s.FillIndicesAndSizes<kTensorSliceMaxRank>(shp_s, &s_start, &s_len);
+ rel_d.FillIndicesAndSizes<kTensorSliceMaxRank>(shp_d, &d_start, &d_len);
+ t_d.slice(d_start, d_len) = t_s.slice(s_start, s_len).template cast<DstT>();
+ return true;
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
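
A sketch of the padding behavior described above (not part of the patch): a rank-2 shape maps onto a rank-8 Eigen tensor whose trailing dimensions are all 1.

    TensorShape shape({4, 5});
    std::vector<float> buf(shape.num_elements());
    auto t = GetEigenTensorMapFromTensorShape(shape, buf.data());
    // t.dimension(0) == 4, t.dimension(1) == 5,
    // t.dimension(2) through t.dimension(7) == 1.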
diff --git a/tensorflow/core/util/tensor_slice_util_test.cc b/tensorflow/core/util/tensor_slice_util_test.cc
new file mode 100644
index 0000000000..348b0c884e
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_util_test.cc
@@ -0,0 +1,91 @@
+#include "tensorflow/core/util/tensor_slice_util.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+// Testing copying data from one tensor slice to another tensor slice
+TEST(TensorSliceUtilTest, CopyTensorSliceToTensorSlice) {
+ // We map out a 2-d tensor of size 4 X 5 and we want the final result to look
+ // like this:
+ //
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // 10 11 12 13 14
+ // 15 16 17 18 19
+ //
+ // We assume this is a row-major matrix
+ //
+ TensorShape shape({4, 5});
+
+ // We will try to do a couple of slice to slice copies.
+
+ // Case 1: simple identity copy
+ // The slice is the "interior" of the matrix
+ // . . . . .
+ // . 6 7 8 .
+ // . 11 12 13 .
+ // . . . . .
+ {
+ TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3");
+ TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,3");
+ const float ptr_s[] = {6, 7, 8, 11, 12, 13};
+ float ptr_d[6];
+ for (int i = 0; i < 6; ++i) {
+ ptr_d[i] = 0;
+ }
+ EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d,
+ ptr_s, ptr_d));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(ptr_s[i], ptr_d[i]);
+ }
+ }
+
+ // Case 2: no intersection
+ {
+ TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3");
+ TensorSlice slice_d = TensorSlice::ParseOrDie("3,1:2,3");
+ const float ptr_s[] = {6, 7, 8, 11, 12, 13};
+ float ptr_d[6];
+ EXPECT_FALSE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d,
+ ptr_s, ptr_d));
+ }
+
+ // Case 3: a trickier case
+ // The source slice is on the upper left corner:
+ // 0 1 2 . .
+ // 5 6 7 . .
+ // 10 11 12 . .
+ // . . . . .
+ //
+ // The destination slice is the right part of the middle stripe:
+ // . . . . .
+ // . X X X X
+ // . X X X X
+ // . . . . .
+ //
+ // So we expect to copy over the 2X2 block:
+ // . . . . .
+ // . 6 7 . .
+ // . 11 12 . .
+ // . . . . .
+ {
+ TensorSlice slice_s = TensorSlice::ParseOrDie("0,3:0,3");
+ TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,4");
+ const float ptr_s[] = {0, 1, 2, 5, 6, 7, 10, 11, 12};
+ float ptr_d[8];
+ for (int i = 0; i < 8; ++i) {
+ ptr_d[i] = 0;
+ }
+ EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d,
+ ptr_s, ptr_d));
+ const float expected[] = {6, 7, 0, 0, 11, 12, 0, 0};
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(expected[i], ptr_d[i]);
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc
new file mode 100644
index 0000000000..bb2fd96c05
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_writer.cc
@@ -0,0 +1,110 @@
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+class TableBuilder : public TensorSliceWriter::Builder {
+ public:
+ TableBuilder(const string& name, WritableFile* f)
+ : name_(name),
+ file_(f),
+ builder_(new table::TableBuilder(table::Options(), f)) {}
+ void Add(StringPiece key, StringPiece val) override {
+ builder_->Add(key, val);
+ }
+ Status Finish(int64* file_size) override {
+ *file_size = -1;
+ Status s = builder_->Finish();
+ if (s.ok()) {
+ s = file_->Close();
+ if (s.ok()) {
+ *file_size = builder_->FileSize();
+ }
+ }
+ if (!s.ok()) {
+ s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ",
+ s.ToString());
+ }
+ builder_.reset();
+ file_.reset();
+ return s;
+ }
+
+ private:
+ string name_;
+ std::unique_ptr<WritableFile> file_;
+ std::unique_ptr<table::TableBuilder> builder_;
+};
+} // anonymous namespace
+
+Status CreateTableTensorSliceBuilder(
+ const string& name, TensorSliceWriter::Builder** builder) {
+ *builder = nullptr;
+ WritableFile* f;
+ Status s = Env::Default()->NewWritableFile(name, &f);
+ if (s.ok()) {
+ *builder = new TableBuilder(name, f);
+ return Status::OK();
+ } else {
+ return s;
+ }
+}
+
+TensorSliceWriter::TensorSliceWriter(const string& filename,
+ CreateBuilderFunction create_builder)
+ : filename_(filename),
+ create_builder_(create_builder),
+ tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
+ slices_(0) {}
+
+Status TensorSliceWriter::Finish() {
+ Builder* b;
+ Status s = create_builder_(tmpname_, &b);
+ if (!s.ok()) {
+ delete b;
+ return s;
+ }
+ std::unique_ptr<Builder> builder(b);
+
+ // We save the saved tensor slice metadata as the first element.
+ string meta;
+ sts_.AppendToString(&meta);
+ builder->Add(kSavedTensorSlicesKey, meta);
+
+ // Go through all the data and add them
+ for (const auto& x : data_) {
+ builder->Add(x.first, x.second);
+ }
+
+ int64 file_size;
+ s = builder->Finish(&file_size);
+ // We need to rename the file to the proper name
+ if (s.ok()) {
+ s = Env::Default()->RenameFile(tmpname_, filename_);
+ if (s.ok()) {
+ VLOG(1) << "Written " << slices_ << " slices for "
+ << sts_.meta().tensor_size() << " tensors (" << file_size
+ << " bytes) to " << filename_;
+ } else {
+ LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_;
+ }
+ } else {
+ Env::Default()->DeleteFile(tmpname_);
+ }
+ return s;
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
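
A writer-side sketch matching the reader example earlier (not part of the patch; the filename is hypothetical):

    TensorSliceWriter writer("/tmp/ckpt_0", CreateTableTensorSliceBuilder);
    TensorShape shape({4, 5});
    const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    // Write rows 0-1 of the 4 x 5 tensor "test".
    TF_CHECK_OK(writer.Add("test", shape, TensorSlice::ParseOrDie("0,2:-"), data));
    TF_CHECK_OK(writer.Finish());  // writes a temp file, then renames it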
diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h
new file mode 100644
index 0000000000..cce3880cb3
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_writer.h
@@ -0,0 +1,149 @@
+// The utility to write checkpoints for google brain tensor ops and v3
+// checkpoints for dist_belief.
+//
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/saved_tensor_slice.pb.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceWriter {
+ public:
+ // Abstract interface that TensorSliceWriter uses to build the output file.
+ class Builder {
+ public:
+ virtual ~Builder() {}
+ virtual void Add(StringPiece key, StringPiece value) = 0;
+ virtual Status Finish(int64* file_size) = 0;
+ };
+ typedef std::function<Status(const string&, Builder**)>
+ CreateBuilderFunction;
+
+ TensorSliceWriter(const string& filename,
+ CreateBuilderFunction create_builder);
+ virtual ~TensorSliceWriter() {}
+ // Adds a slice of a tensor. Supported element types include float, int32,
+ // int64, and int16.
+ // TODO(yangke): add support for more types.
+ template <typename T>
+ Status Add(const string& name, const TensorShape& shape,
+ const TensorSlice& slice, const T* data);
+ Status Finish();
+
+ private:
+ // Allocates "num_elements" elements in "ss" and copies the contents of
+ // "data" into them.
+ template <typename T>
+ static void SaveData(const T* data, int num_elements, SavedSlice* ss);
+
+ const string filename_;
+ const CreateBuilderFunction create_builder_;
+ const string tmpname_;
+
+ // A mapping from the tensor names to their index in meta_.saved_slice_meta()
+ std::unordered_map<string, int> name_to_index_;
+ // The metadata that holds all the saved tensor slices.
+ SavedTensorSlices sts_;
+ // The data to be written to the builder
+ std::map<string, string> data_;
+ // Total number of slices written
+ int slices_;
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceWriter);
+};
+
+template <typename T>
+Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
+ const TensorSlice& slice, const T* data) {
+ // The tensor and the slice have to be compatible
+ if (shape.dims() != slice.dims()) {
+ return errors::Internal("Incompatible tensor shape and slice: ", "shape = ",
+ shape.DebugString(), ", slice = ",
+ slice.DebugString());
+ }
+ DataType dt = DataTypeToEnum<T>::value;
+ // We need to add an entry for "name" if there isn't an entry already.
+ int index = gtl::FindWithDefault(name_to_index_, name, -1);
+ if (index >= 0) {
+ // The same tensor has been registered -- we verify that the shapes and the
+ // type agree.
+ const SavedSliceMeta& ssm = sts_.meta().tensor(index);
+ CHECK_EQ(name, ssm.name()) << ssm.ShortDebugString();
+ TensorShape ssm_shape(ssm.shape());
+ if (!shape.IsSameSize(ssm_shape)) {
+ return errors::Internal("Mismatching shapes: existing tensor = ",
+ ssm_shape.DebugString(), ", trying to add name ",
+ name, ", shape = ", shape.DebugString());
+ }
+ if (dt != ssm.type()) {
+ return errors::Internal(
+ "Mismatching types: existing type = ", DataTypeString(ssm.type()),
+ ", trying to add name ", name, ", type = ", DataTypeString(dt));
+ }
+ } else {
+ // Insert the new tensor name with the shape information
+ index = sts_.meta().tensor_size();
+ name_to_index_.insert(std::make_pair(name, index));
+ SavedSliceMeta* ssm = sts_.mutable_meta()->add_tensor();
+ ssm->set_name(name);
+ shape.AsProto(ssm->mutable_shape());
+ ssm->set_type(dt);
+ }
+ // Now we need to add the slice info to the list of slices.
+ SavedSliceMeta* ssm = sts_.mutable_meta()->mutable_tensor(index);
+ slice.AsProto(ssm->add_slice());
+
+ // Now we need to add the real data.
+ {
+ SavedTensorSlices sts;
+ SavedSlice* ss = sts.mutable_data();
+ ss->set_name(name);
+ slice.AsProto(ss->mutable_slice());
+ TensorShape saved_shape(ssm->shape());
+ TensorShape sliced_shape;
+ TF_RETURN_IF_ERROR(slice.SliceTensorShape(saved_shape, &sliced_shape));
+ SaveData(data, sliced_shape.num_elements(), ss);
+ string key = EncodeTensorNameSlice(name, slice);
+ // TODO(yangke): consider doing a two-pass thing where the first pass just
+ // list the tensor slices we want to save and then another pass to actually
+ // set the data. Need to figure out if the interface works well.
+ std::pair<string, string> key_value(key, "");
+ sts.AppendToString(&key_value.second);
+ data_.insert(key_value);
+ }
+ ++slices_;
+ return Status::OK();
+}
+
+template <typename T>
+void TensorSliceWriter::SaveData(const T* data, int num_elements,
+ SavedSlice* ss) {
+ Fill(data, num_elements, ss->mutable_data());
+}
+
+// Create a table builder that will write to "filename" in
+// tensorflow::io::Table format. If successful, return OK
+// and set "*builder" to the allocated builder. Otherwise, return a
+// non-OK status.
+Status CreateTableTensorSliceBuilder(const string& filename,
+ TensorSliceWriter::Builder** builder);
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
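
For orientation, a minimal caller-side sketch of the writer API above. This is an illustrative example, not part of the change: the tensor name "w", the shape, and the slice specs are invented, and error handling is reduced to TF_CHECK_OK.

#include "tensorflow/core/util/tensor_slice_writer.h"

void WriteExampleCheckpoint(const tensorflow::string& path) {
  using namespace tensorflow;
  // Back the writer with the table-format builder declared above.
  checkpoint::TensorSliceWriter writer(
      path, checkpoint::CreateTableTensorSliceBuilder);
  TensorShape shape({4, 10});
  // Column 0 of a 4x10 float tensor: all of dim 0, extent [0, 1) of dim 1.
  TensorSlice col0 = TensorSlice::ParseOrDie("-:0,1");
  const float data0[4] = {0.0f, 1.0f, 2.0f, 3.0f};
  TF_CHECK_OK(writer.Add("w", shape, col0, data0));
  // A second slice of the same tensor; shape and type must agree.
  TensorSlice col1 = TensorSlice::ParseOrDie("-:1,1");
  const float data1[4] = {4.0f, 5.0f, 6.0f, 7.0f};
  TF_CHECK_OK(writer.Add("w", shape, col1, data1));
  // Finish() writes metadata plus data to a temp file, then renames it
  // into place.
  TF_CHECK_OK(writer.Finish());
}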
diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc
new file mode 100644
index 0000000000..ca3dffe422
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_writer_test.cc
@@ -0,0 +1,248 @@
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/test.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceWriteTestHelper {
+ public:
+ static void CheckEntries(const string& fname);
+ static void GetData(TensorSliceReader::Table* table, const string& name,
+ const TensorSlice& slice, SavedSlice* ss);
+};
+
+namespace {
+
+// Tests that an array matches the expected values element-wise.
+void ExpectIdenticalFloatArrays(const float* expected, int size,
+ const float* actual) {
+ // TODO(yangke): copy some of the Dump* functions over
+ // LOG(INFO) << "Expected = " << DumpFloatArray(expected, size);
+ // LOG(INFO) << "Actual = " << DumpFloatArray(actual, size);
+ for (int i = 0; i < size; ++i) {
+ EXPECT_NEAR(expected[i], actual[i], 1e-6);
+ }
+}
+
+template <typename T, typename U>
+void ExpectIdenticalIntArrays(const T* expected, int size, const U* actual) {
+ for (int i = 0; i < size; ++i) {
+ EXPECT_EQ(expected[i], static_cast<T>(actual[i]));
+ }
+}
+
+// Nifty routine to get the size of an array
+template <typename T, unsigned SIZE>
+inline size_t ArraySize(const T(&v)[SIZE]) {
+ return SIZE;
+}
+
+// A simple test on writing a few tensor slices
+// TODO(yangke): refactor into smaller tests: will do as we add more stuff to
+// the writer.
+TEST(TensorSliceWriteTest, SimpleWrite) {
+ const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");
+
+ TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);
+
+ // Add some int32 tensor slices
+ {
+ TensorShape shape({5, 10});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:0,1");
+ const int32 data[] = {0, 1, 2, 3, 4};
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+
+ // Two slices share the same tensor name
+ {
+ TensorShape shape({5, 10});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
+ const int32 data[] = {10, 11, 12, 13, 14};
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+
+ // Another slice from a different float tensor -- it has a different name and
+ // should be inserted in front of the previous tensor
+ {
+ TensorShape shape({3, 2});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:-");
+ const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
+ TF_CHECK_OK(writer.Add("AA", shape, slice, data));
+ }
+
+ // A slice with int64 data
+ {
+ TensorShape shape({5, 10});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
+ const int64 data[] = {10, 11, 12, 13, 14};
+ TF_CHECK_OK(writer.Add("int64", shape, slice, data));
+ }
+
+ // A slice with int16 data
+ {
+ TensorShape shape({5, 10});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
+ const int16 data[] = {10, 11, 12, 13, 14};
+ TF_CHECK_OK(writer.Add("int16", shape, slice, data));
+ }
+
+ TF_CHECK_OK(writer.Finish());
+
+ // Now we examine the checkpoint file manually.
+ TensorSliceWriteTestHelper::CheckEntries(filename);
+}
+
+} // namespace
+
+void TensorSliceWriteTestHelper::GetData(TensorSliceReader::Table* table,
+ const string& name,
+ const TensorSlice& slice,
+ SavedSlice* ss) {
+ string key = EncodeTensorNameSlice(name, slice);
+ string value;
+ EXPECT_TRUE(table->Get(key, &value));
+ SavedTensorSlices sts;
+ EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
+ EXPECT_FALSE(sts.has_meta());
+ *ss = sts.data();
+ EXPECT_EQ(name, ss->name());
+ TensorSlice slice2(ss->slice());
+ EXPECT_EQ(slice.DebugString(), slice2.DebugString());
+}
+
+void TensorSliceWriteTestHelper::CheckEntries(const string& fname) {
+ TensorSliceReader::Table* tptr;
+ TF_CHECK_OK(OpenTableTensorSliceReader(fname, &tptr));
+ std::unique_ptr<TensorSliceReader::Table> table(tptr);
+ CHECK_NOTNULL(table.get());
+
+ // We expect a block of SavedTensorSlices
+ string value;
+ ASSERT_TRUE(table->Get(kSavedTensorSlicesKey, &value));
+ {
+ SavedTensorSlices sts;
+ EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
+ // We also expect an entry for each of the four tensors.
+ EXPECT_TRUE(sts.has_meta());
+ EXPECT_EQ(4, sts.meta().tensor_size());
+ // We don't expect any data in the first block.
+ EXPECT_FALSE(sts.has_data());
+ // The tensors should be stored in the same order as they were first
+ // added.
+ {
+ // The two slices of the "test" tensor
+ const SavedSliceMeta& ssm = sts.meta().tensor(0);
+ EXPECT_EQ("test", ssm.name());
+ EXPECT_EQ(
+ "dim { size: 5 } "
+ "dim { size: 10 }",
+ ssm.shape().ShortDebugString());
+ EXPECT_EQ(DT_INT32, ssm.type());
+ EXPECT_EQ(2, ssm.slice_size());
+ TensorSlice s0(ssm.slice(0));
+ TensorSlice s1(ssm.slice(1));
+ EXPECT_EQ("-:0,1", s0.DebugString());
+ EXPECT_EQ("-:3,1", s1.DebugString());
+ }
+ {
+ // The "AA" tensor
+ const SavedSliceMeta& ssm = sts.meta().tensor(1);
+ EXPECT_EQ("AA", ssm.name());
+ EXPECT_EQ(
+ "dim { size: 3 } "
+ "dim { size: 2 }",
+ ssm.shape().ShortDebugString());
+ EXPECT_EQ(DT_FLOAT, ssm.type());
+ EXPECT_EQ(1, ssm.slice_size());
+ TensorSlice s0(ssm.slice(0));
+ EXPECT_EQ("-:-", s0.DebugString());
+ }
+ {
+ // The "int64" tensor
+ const SavedSliceMeta& ssm = sts.meta().tensor(2);
+ EXPECT_EQ("int64", ssm.name());
+ EXPECT_EQ(
+ "dim { size: 5 } "
+ "dim { size: 10 }",
+ ssm.shape().ShortDebugString());
+ EXPECT_EQ(DT_INT64, ssm.type());
+ EXPECT_EQ(1, ssm.slice_size());
+ TensorSlice s0(ssm.slice(0));
+ EXPECT_EQ("-:3,1", s0.DebugString());
+ }
+ {
+ // The "int16" tensor
+ const SavedSliceMeta& ssm = sts.meta().tensor(3);
+ EXPECT_EQ("int16", ssm.name());
+ EXPECT_EQ(
+ "dim { size: 5 } "
+ "dim { size: 10 }",
+ ssm.shape().ShortDebugString());
+ EXPECT_EQ(DT_INT16, ssm.type());
+ EXPECT_EQ(1, ssm.slice_size());
+ TensorSlice s0(ssm.slice(0));
+ EXPECT_EQ("-:3,1", s0.DebugString());
+ }
+ }
+
+ // We expect 5 blocks of tensor data
+ {
+ // Block 1: we expect it to be the full slice of the "AA" tensor
+ SavedSlice ss;
+ GetData(table.get(), "AA", TensorSlice(2), &ss);
+ const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
+ EXPECT_EQ(ArraySize(data), ss.data().float_val_size());
+ ExpectIdenticalFloatArrays(data, ArraySize(data),
+ ss.data().float_val().data());
+ }
+
+ {
+ // Block 2: we expect it to be the first slice of the "test" tensor
+ SavedSlice ss;
+ GetData(table.get(), "test", TensorSlice({{0, -1}, {0, 1}}), &ss);
+ const int32 data[] = {0, 1, 2, 3, 4};
+ EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
+ ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
+ }
+
+ {
+ // Block 3: we expect it to be the second slice of the "test" tensor
+ SavedSlice ss;
+ GetData(table.get(), "test", TensorSlice({{0, -1}, {3, 1}}), &ss);
+ const int32 data[] = {10, 11, 12, 13, 14};
+ EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
+ ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
+ }
+
+ {
+ // Block 4: we expect it to be the slice of the "int64" tensor
+ SavedSlice ss;
+ GetData(table.get(), "int64", TensorSlice({{0, -1}, {3, 1}}), &ss);
+ const int64 data[] = {10, 11, 12, 13, 14};
+ EXPECT_EQ(ArraySize(data), ss.data().int64_val_size());
+ ExpectIdenticalIntArrays(data, ArraySize(data),
+ ss.data().int64_val().data());
+ }
+
+ {
+ // Block 5: we expect it to be the slice of the "int16" tensor
+ SavedSlice ss;
+ GetData(table.get(), "int16", TensorSlice({{0, -1}, {3, 1}}), &ss);
+ const int16 data[] = {10, 11, 12, 13, 14};
+ EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
+ ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
+ }
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/use_cudnn.cc b/tensorflow/core/util/use_cudnn.cc
new file mode 100644
index 0000000000..544b48a679
--- /dev/null
+++ b/tensorflow/core/util/use_cudnn.cc
@@ -0,0 +1,20 @@
+#include "tensorflow/core/util/use_cudnn.h"
+
+#include <stdlib.h>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+bool CanUseCudnn() {
+ const char* tf_use_cudnn = getenv("TF_USE_CUDNN");
+ if (tf_use_cudnn != nullptr) {
+ string tf_use_cudnn_str = tf_use_cudnn;
+ if (tf_use_cudnn_str == "0") {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h
new file mode 100644
index 0000000000..20ce24c513
--- /dev/null
+++ b/tensorflow/core/util/use_cudnn.h
@@ -0,0 +1,12 @@
+// Utility to check whether the cuDNN dependency is usable.
+
+#ifndef TENSORFLOW_UTIL_USE_CUDNN_H_
+#define TENSORFLOW_UTIL_USE_CUDNN_H_
+
+namespace tensorflow {
+
+bool CanUseCudnn();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_USE_CUDNN_H_
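
To illustrate the flag above: CanUseCudnn() returns true unless the TF_USE_CUDNN environment variable is set to exactly "0". A small sketch, assuming a POSIX environment (setenv/unsetenv) and a hypothetical call site:

#include <cstdlib>

#include "tensorflow/core/util/use_cudnn.h"

void CudnnFlagExample() {
  // Unset, or set to anything other than "0", leaves cuDNN enabled.
  unsetenv("TF_USE_CUDNN");
  const bool enabled = tensorflow::CanUseCudnn();  // true
  // Setting it to "0" disables cuDNN.
  setenv("TF_USE_CUDNN", "0", 1 /* overwrite */);
  const bool disabled = !tensorflow::CanUseCudnn();  // true
  (void)enabled;
  (void)disabled;
}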
diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc
new file mode 100644
index 0000000000..14ac513074
--- /dev/null
+++ b/tensorflow/core/util/util.cc
@@ -0,0 +1,81 @@
+#include "tensorflow/core/util/util.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+StringPiece NodeNamePrefix(const StringPiece& op_name) {
+ StringPiece sp(op_name);
+ auto p = sp.find('/');
+ if (p == StringPiece::npos || p == 0) {
+ return "";
+ } else {
+ return StringPiece(sp.data(), p);
+ }
+}
+
+StringPiece NodeNameFullPrefix(const StringPiece& op_name) {
+ StringPiece sp(op_name);
+ auto p = sp.rfind('/');
+ if (p == StringPiece::npos || p == 0) {
+ return "";
+ } else {
+ return StringPiece(sp.data(), p);
+ }
+}
+
+MovingAverage::MovingAverage(int window)
+ : window_(window),
+ sum_(0.0),
+ data_(new double[window_]),
+ head_(0),
+ count_(0) {
+ CHECK_GE(window, 1);
+}
+
+MovingAverage::~MovingAverage() { delete[] data_; }
+
+void MovingAverage::Clear() {
+ count_ = 0;
+ head_ = 0;
+ sum_ = 0;
+}
+
+double MovingAverage::GetAverage() const {
+ if (count_ == 0) {
+ return 0;
+ } else {
+ return static_cast<double>(sum_) / count_;
+ }
+}
+
+void MovingAverage::AddValue(double v) {
+ if (count_ < window_) {
+ // This is the warmup phase. We don't have a full window's worth of data.
+ head_ = count_;
+ data_[count_++] = v;
+ } else {
+ if (window_ == ++head_) {
+ head_ = 0;
+ }
+ // Toss the oldest element
+ sum_ -= data_[head_];
+ // Add the newest element
+ data_[head_] = v;
+ }
+ sum_ += v;
+}
+
+static const char hex_char[] = "0123456789abcdef";
+
+string PrintMemory(const char* ptr, int n) {
+ string ret;
+ ret.resize(n * 3);
+ for (int i = 0; i < n; ++i) {
+ ret[i * 3] = ' ';
+ // Cast to unsigned char so negative chars cannot yield a negative index.
+ ret[i * 3 + 1] = hex_char[static_cast<unsigned char>(ptr[i]) >> 4];
+ ret[i * 3 + 2] = hex_char[static_cast<unsigned char>(ptr[i]) & 0xf];
+ }
+ return ret;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h
new file mode 100644
index 0000000000..52650bd8ea
--- /dev/null
+++ b/tensorflow/core/util/util.h
@@ -0,0 +1,40 @@
+#ifndef TENSORFLOW_UTIL_UTIL_H_
+#define TENSORFLOW_UTIL_UTIL_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+// If op_name has '/' in it, then return everything before the first '/'.
+// Otherwise return empty string.
+StringPiece NodeNamePrefix(const StringPiece& op_name);
+
+// If op_name has '/' in it, then return everything before the last '/'.
+// Otherwise return empty string.
+StringPiece NodeNameFullPrefix(const StringPiece& op_name);
+
+class MovingAverage {
+ public:
+ explicit MovingAverage(int window);
+ ~MovingAverage();
+
+ void Clear();
+
+ double GetAverage() const;
+ void AddValue(double v);
+
+ private:
+ const int window_; // Max size of interval
+ double sum_; // Sum over interval
+ double* data_; // Actual data values
+ int head_; // Offset of the newest statistic in data_
+ int count_; // # of valid data elements in window
+};
+
+// Returns a string printing bytes in ptr[0..n). Each byte is preceded by
+// a space, so the output looks like " 00 01 ef cd cd ef".
+string PrintMemory(const char* ptr, int n);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_UTIL_H_
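
A short usage sketch for the two utilities above; the values are arbitrary and the eviction walk-through follows AddValue() in util.cc:

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"

void UtilExample() {
  tensorflow::MovingAverage avg(3 /* window */);
  avg.AddValue(1.0);
  avg.AddValue(2.0);
  avg.AddValue(3.0);
  LOG(INFO) << avg.GetAverage();  // (1 + 2 + 3) / 3 = 2.0
  avg.AddValue(7.0);              // evicts the oldest value, 1.0
  LOG(INFO) << avg.GetAverage();  // (2 + 3 + 7) / 3 = 4.0

  const char bytes[] = {0x00, 0x01, static_cast<char>(0xef)};
  LOG(INFO) << tensorflow::PrintMemory(bytes, 3);  // " 00 01 ef"
}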
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
new file mode 100644
index 0000000000..d9ab0805c5
--- /dev/null
+++ b/tensorflow/core/util/work_sharder.cc
@@ -0,0 +1,57 @@
+#include "tensorflow/core/util/work_sharder.h"
+
+#include <vector>
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+void Shard(int num_workers, thread::ThreadPool* workers, int64 total,
+ int64 cost_per_unit, std::function<void(int64, int64)> work) {
+ CHECK_GE(total, 0);
+ if (total == 0) {
+ return;
+ }
+ if (num_workers <= 1) {
+ // Just inline the whole work since we only have 1 thread (core).
+ work(0, total);
+ return;
+ }
+ cost_per_unit = std::max(1LL, cost_per_unit);
+ // We shard [0, total) into "num_shards" shards.
+ // 1 <= num_shards <= num worker threads
+ //
+ // If total * cost_per_unit is small, it is not worth sharding too
+ // much. Assuming each cost unit takes roughly 1ns, kMinCostPerShard =
+ // 10000 corresponds to about 10us of work per shard.
+ static const int64 kMinCostPerShard = 10000;
+ const int num_shards = std::max(
+ 1, std::min<int>(num_workers, total * cost_per_unit / kMinCostPerShard));
+ // Each shard contains up to "block_size" units. [0, total) is sharded
+ // into:
+ // [0, block_size), [block_size, 2*block_size), ...
+ // The 1st shard is done by the caller thread and the other shards
+ // are dispatched to the worker threads. The last shard may be smaller than
+ // block_size.
+ const int64 block_size = (total + num_shards - 1) / num_shards;
+ CHECK_GT(block_size, 0); // total > 0 guarantees this.
+ if (block_size >= total) {
+ work(0, total);
+ return;
+ }
+ const int num_shards_used = (total + block_size - 1) / block_size;
+ BlockingCounter counter(num_shards_used - 1);
+ for (int64 start = block_size; start < total; start += block_size) {
+ auto limit = std::min(start + block_size, total);
+ workers->Schedule([&work, &counter, start, limit]() {
+ work(start, limit); // Compute the shard.
+ counter.DecrementCount(); // The shard is done.
+ });
+ }
+
+ // Inline execute the 1st shard.
+ work(0, std::min(block_size, total));
+ counter.Wait();
+}
+
+} // end namespace tensorflow
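
To make the shard-size arithmetic above concrete: with num_workers = 8, total = 1000, and cost_per_unit = 100, total * cost_per_unit / kMinCostPerShard = 1000 * 100 / 10000 = 10, which is clamped to num_workers, so num_shards = 8 and block_size = (1000 + 8 - 1) / 8 = 125. The caller thread runs [0, 125) inline while the remaining seven shards [125, 250), ..., [875, 1000) are dispatched to the worker threads.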
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
new file mode 100644
index 0000000000..1ea2cf4397
--- /dev/null
+++ b/tensorflow/core/util/work_sharder.h
@@ -0,0 +1,33 @@
+#ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_
+#define TENSORFLOW_UTIL_WORK_SHARDER_H_
+
+#include <functional>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+
+// Shards the "total" units of work, assuming each unit costs roughly
+// "cost_per_unit". Each unit of work is indexed 0, 1, ...,
+// total - 1. Each shard contains one or more units of work, and the
+// total cost of each shard is roughly the same. The total number of
+// shards is no more than num_workers. The calling thread and the
+// "workers" are used to compute each shard (by calling work(start,
+// limit)). A common configuration is that "workers" is a thread pool
+// with "num_workers" threads.
+//
+// "work" should be a callable taking (int64, int64) arguments.
+// work(start, limit) computes the work units from [start,
+// limit), i.e., [start, limit) is a shard.
+//
+// REQUIRES: num_workers >= 0
+// REQUIRES: workers != nullptr
+// REQUIRES: total >= 0
+// REQUIRES: cost_per_unit >= 0
+void Shard(int num_workers, thread::ThreadPool* workers, int64 total,
+ int64 cost_per_unit, std::function<void(int64, int64)> work);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_WORK_SHARDER_H_
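
A minimal sketch of a Shard() call site, complementing the contract above. The pool size and the cost estimate of 20 (ns per element) are illustrative; note that Shard() blocks until all shards complete, so writing into the vector is safe:

#include <vector>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/env.h"
#include "tensorflow/core/util/work_sharder.h"

void ParallelFill(std::vector<float>* out) {
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "fill", 4);
  const tensorflow::int64 n = out->size();
  tensorflow::Shard(4 /* num_workers */, &pool, n, 20 /* cost_per_unit */,
                    [out](tensorflow::int64 start, tensorflow::int64 limit) {
                      // Each shard fills its own disjoint range [start, limit).
                      for (tensorflow::int64 i = start; i < limit; ++i) {
                        (*out)[i] = 0.5f * static_cast<float>(i);
                      }
                    });
}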
diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc
new file mode 100644
index 0000000000..d9792c0e8d
--- /dev/null
+++ b/tensorflow/core/util/work_sharder_test.cc
@@ -0,0 +1,57 @@
+#include "tensorflow/core/util/work_sharder.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit) {
+ thread::ThreadPool threads(Env::Default(), "test", 16);
+ mutex mu;
+ int64 num_shards = 0;
+ int64 num_done_work = 0;
+ std::vector<bool> work(total, false);
+ Shard(num_workers, &threads, total, cost_per_unit,
+ [&mu, &num_shards, &num_done_work, &work](int64 start, int64 limit) {
+ VLOG(1) << "Shard [" << start << "," << limit << ")";
+ mutex_lock l(mu);
+ ++num_shards;
+ for (; start < limit; ++start) {
+ EXPECT_FALSE(work[start]); // No duplicate
+ ++num_done_work;
+ work[start] = true;
+ }
+ });
+ EXPECT_LE(num_shards, num_workers + 1);
+ EXPECT_EQ(num_done_work, total);
+ LOG(INFO) << num_workers << " " << total << " " << cost_per_unit << " "
+ << num_shards;
+}
+
+TEST(Shard, Basic) {
+ for (auto workers : {0, 1, 2, 3, 5, 7, 10, 11, 15, 100, 1000}) {
+ for (auto total : {0, 1, 7, 10, 64, 100, 256, 1000, 9999}) {
+ for (auto cost_per_unit : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+ RunSharding(workers, total, cost_per_unit);
+ }
+ }
+ }
+}
+
+void BM_Sharding(int iters, int arg) {
+ thread::ThreadPool threads(Env::Default(), "test", 16);
+ const int64 total = 1LL << 30;
+ auto lambda = [](int64 start, int64 limit) {};
+ auto work = std::cref(lambda);
+ for (; iters > 0; iters -= arg) {
+ Shard(arg - 1, &threads, total, 1, work);
+ }
+}
+BENCHMARK(BM_Sharding)->Range(1, 128);
+
+} // namespace
+} // namespace tensorflow