Diffstat (limited to 'tensorflow/core/util')
46 files changed, 5869 insertions, 0 deletions
diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc new file mode 100644 index 0000000000..4e70b78751 --- /dev/null +++ b/tensorflow/core/util/bcast.cc @@ -0,0 +1,120 @@ +#include "tensorflow/core/util/bcast.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +/* static */ +void BCast::Reverse(Vec* shape) { std::reverse(shape->begin(), shape->end()); } + +BCast::BCast(const Vec& sx, const Vec& sy) { + // Reverse the shape of x and y for convenience. + // After the reverse, 0-th is the inner-most dimension. + Vec x = sx; + Reverse(&x); + Vec y = sy; + Reverse(&y); + + // 1-extend and align x and y so that they are the same size. + if (x.size() > y.size()) { + y.resize(x.size(), 1); + } else { + x.resize(y.size(), 1); + } + + // Going through each dimension starting from the inner-most + // dimension, compares dimension of x and y. They are compatible if + // they are equal or either is 1. + enum State { + UNKNOWN, + SAME, + X_ONE, + Y_ONE, + }; + State prev = UNKNOWN; + const int64 n = x.size(); + for (int i = 0; i < n; ++i) { + // Output shape. + State curr = UNKNOWN; + const int64 x_i = x[i]; // i-th dimension of x. + CHECK_GE(x_i, 0); + const int64 y_i = y[i]; // i-th dimension of y. + CHECK_GE(y_i, 0); + int64 o_i; // i-th dimension of the output. + int64 bx_i; // i-th broadcast for x. + int64 by_i; // i-th broadcast for y. + // Invariant: + // o_i = x_i * bx_i = y_i * by_i + if (x_i == y_i) { + // No broadcast. + o_i = x_i; + bx_i = 1; + by_i = 1; + curr = SAME; + } else if (x_i == 1) { + // x broadcast to y on this dimension. + o_i = y_i; + bx_i = y_i; + by_i = 1; + grad_x_reduce_idx_.push_back(n - 1 - i); + curr = X_ONE; + } else if (y_i == 1) { + // y broadcast to x on this dimension. + o_i = x_i; + bx_i = 1; + by_i = x_i; + grad_y_reduce_idx_.push_back(n - 1 - i); + curr = Y_ONE; + } else { + valid_ = false; + return; + } + output_.push_back(o_i); + // Reshape/broadcast. + // Invariant: + // result[i] == x_reshape[i] * x_bcast[i] == y_reshape_[i] * y_bcast_[i] + if (curr == SAME && x_i == 1) { + // Both side are 1s. + grad_x_reduce_idx_.push_back(n - 1 - i); + grad_y_reduce_idx_.push_back(n - 1 - i); + continue; + } else if (prev == curr) { + // It is a run of the same cases (no broadcast, x broadcast to + // y, y broadcast to x). We can reshape the input so that fewer + // dimensions are involved in the intermediate computation. + result_.back() *= o_i; + x_reshape_.back() *= x_i; + x_bcast_.back() *= bx_i; + y_reshape_.back() *= y_i; + y_bcast_.back() *= by_i; + } else { + result_.push_back(o_i); + x_reshape_.push_back(x_i); + x_bcast_.push_back(bx_i); + y_reshape_.push_back(y_i); + y_bcast_.push_back(by_i); + } + prev = curr; + } + + if (result_.empty()) { + // Can happen when both x and y are effectively scalar. + result_.push_back(1); + x_reshape_.push_back(1); + x_bcast_.push_back(1); + y_reshape_.push_back(1); + y_bcast_.push_back(1); + } + + // Reverse all vectors since x and y were reversed at very + // beginning. 
+ Reverse(&x_reshape_); + Reverse(&x_bcast_); + Reverse(&y_reshape_); + Reverse(&y_bcast_); + Reverse(&result_); + Reverse(&output_); + Reverse(&grad_x_reduce_idx_); + Reverse(&grad_y_reduce_idx_); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h new file mode 100644 index 0000000000..9f0233e415 --- /dev/null +++ b/tensorflow/core/util/bcast.h @@ -0,0 +1,99 @@ +#ifndef TENSORFLOW_UTIL_BCAST_H_ +#define TENSORFLOW_UTIL_BCAST_H_ + +#include <algorithm> +#include <vector> + +#include "tensorflow/core/platform/port.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +// BCast is a helper for broadcasting binary tensor operation. +// TensorFlow's broadcasting rule follows that of numpy (See +// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). +// +// The rule has the following properties: +// +// 1. suffix matching: the rule starts with the right-most +// dimension, and works towards the left-most dimension. Since +// TensorFlow is row-major, the right-most dimension (the last +// element in the shape of a tensor) is the inner-most, a.k.a. +// the fastest changing, dimension. +// +// 2. Two dimensions are compatible for broadcasting if both are the +// same or either is 1. +// +// BCast takes the shape of two tensors and computes a few vectors of +// int32 that are useful for the caller to reshape the tensors, apply +// the right broadcasts to them, compute the broadcasted operation, +// and possibly the gradients. In a nutshell, the caller is expected +// to compute the broadcasted operation as following: +// +// BCast b(x.shape(), y.shape()); +// output = x.reshape(b.x_reshape()).broadcast(b.x_bcast()) +// _op_ +// y.reshape(b.y_reshape()).broadcast(b.y_bcast()) +// +// For the gradient computation, +// grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx) +// .reshape(x.shape()) +// grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx) +// .reshape(y.shape()) +// backprop_x and backprop_y are functionals of the binary function "op", +// e.g., +// for +, backprop_x(x, y) = backprop_y(x, y) = 1; +// for *, backprop_x(x, y) = y, backprop_y(x, y) = x; +// for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2; +// +// The multiplication in the grad * backprop_x itself is also +// broadcasting following the same rule. +// +// TODO(zhifengc): Adds support for n-ary (n >= 2). +class BCast { + public: + // A vector of int32 representing the shape of tensor. The 0-th + // element is the outer-most dimension and the last element is the + // inner-most dimension. Note that we do not use TensorShape since + // it's more convenient to manipulate Vec directly for this module. + typedef std::vector<int64> Vec; + + BCast(const Vec& x, const Vec& y); + ~BCast() {} + + // Returns true iff two operands are compatible according to the + // broadcasting rule. + bool IsValid() const { return valid_; } + + // If and only if IsValid(), the following fields can be used in + // implementing a broadcasted binary tensor operation according to + // the broadcasting rule. 
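+  // For example (values taken from bcast_test.cc below), for x shape
+  // [11, 7, 5, 3, 2] and y shape [3, 2]:
+  //   x_reshape() == {385, 6},   x_bcast() == {1, 1},
+  //   y_reshape() == {1, 6},     y_bcast() == {385, 1},
+  //   result_shape() == {385, 6}, output_shape() == {11, 7, 5, 3, 2},
+  //   grad_x_reduce_idx() == {},  grad_y_reduce_idx() == {0, 1, 2}.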
+ const Vec& x_reshape() const { return x_reshape_; } + const Vec& x_bcast() const { return x_bcast_; } + const Vec& y_reshape() const { return y_reshape_; } + const Vec& y_bcast() const { return y_bcast_; } + const Vec& result_shape() const { return result_; } + const Vec& output_shape() const { return output_; } + const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; } + const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; } + + private: + bool valid_ = true; + Vec x_reshape_; + Vec x_bcast_; + Vec y_reshape_; + Vec y_bcast_; + Vec result_; + Vec output_; + Vec grad_x_reduce_idx_; + Vec grad_y_reduce_idx_; + + static void Reverse(Vec* shape); + static bool HasZero(const Vec& shape); + + TF_DISALLOW_COPY_AND_ASSIGN(BCast); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_BCAST_H_ diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc new file mode 100644 index 0000000000..02d18586d6 --- /dev/null +++ b/tensorflow/core/util/bcast_test.cc @@ -0,0 +1,226 @@ +#include "tensorflow/core/util/bcast.h" + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y) { + tensorflow::BCast b(x, y); + if (!b.IsValid()) { + return "invalid"; + } + string ret; + strings::StrAppend(&ret, "[", str_util::Join(b.x_reshape(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.x_bcast(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.y_reshape(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.y_bcast(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.result_shape(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.output_shape(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.grad_x_reduce_idx(), ","), + "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.grad_y_reduce_idx(), ","), + "]"); + return ret; +} + +TEST(BCastTest, Invalid) { + EXPECT_EQ("invalid", BCast({5, 3, 2}, {3})); + EXPECT_EQ("invalid", BCast({5, 3, 2}, {2, 2})); + EXPECT_EQ("invalid", BCast({5, 3, 2}, {10, 1, 1})); + EXPECT_EQ("invalid", BCast({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1})); +} + +TEST(BCastTest, Basic_SameShape) { + // Effectively no broadcast needed. + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}), + "[2310][1][2310][1]" + "[2310]" + "[11,7,5,3,2]" + "[][]"); +} + +TEST(BCastTest, Basic_Scalar_Scalar) { + // Effectively it's a scalar and a scalar. + // [1, 1] [1] + EXPECT_EQ(BCast({1, 1}, {1}), + "[1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1]"); + + // [1] [1, 1] + EXPECT_EQ(BCast({1}, {1, 1}), + "[1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1]"); +} + +TEST(BCastTest, Basic_Tensor_Scalar) { + // Effectively it's a tensor and a scalar. + // [11, 7, 5, 3, 2] [1] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {1}), + "[2310][1][1][2310]" + "[2310]" + "[11,7,5,3,2]" + "[][0,1,2,3,4]"); + + // [1] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2}), + "[1][2310][2310][1]" + "[2310]" + "[11,7,5,3,2]" + "[0,1,2,3,4][]"); +} + +TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) { + // Effectively it's a tensor and a scalar. 
+ // [11, 7, 5, 3, 2, 1] [1] + EXPECT_EQ(BCast({11, 7, 5, 3, 2, 1}, {1}), + "[2310][1][1][2310]" + "[2310]" + "[11,7,5,3,2,1]" + "[5][0,1,2,3,4,5]"); + + // [1] [11, 7, 5, 3, 2, 1] + EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2, 1}), + "[1][2310][2310][1]" + "[2310]" + "[11,7,5,3,2,1]" + "[0,1,2,3,4,5][5]"); + + // Effectively it's a tensor and a scalar. + // [11, 7, 5, 1, 1, 3, 2, 1] [1] + EXPECT_EQ(BCast({11, 7, 5, 1, 1, 3, 2, 1, 1}, {1}), + "[2310][1][1][2310]" + "[2310]" + "[11,7,5,1,1,3,2,1,1]" + "[3,4,7,8][0,1,2,3,4,5,6,7,8]"); + + // [1] [11, 7, 5, 1, 1, 3, 2, 1] + EXPECT_EQ(BCast({1}, {11, 7, 5, 1, 1, 3, 2, 1, 1}), + "[1][2310][2310][1]" + "[2310]" + "[11,7,5,1,1,3,2,1,1]" + "[0,1,2,3,4,5,6,7,8][3,4,7,8]"); +} + +TEST(BCastTest, Basic_Tensor_Vector) { + // [11, 7, 5, 3, 2] [2] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {2}), + "[1155,2][1,1][1,2][1155,1]" + "[1155,2]" + "[11,7,5,3,2]" + "[][0,1,2,3]"); + + // [2] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({2}, {11, 7, 5, 3, 2}), + "[1,2][1155,1][1155,2][1,1]" + "[1155,2]" + "[11,7,5,3,2]" + "[0,1,2,3][]"); +} + +TEST(BCastTest, Basic_Tensor_Matrix) { + // [11, 7, 5, 3, 2] [3, 2] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 2}), + "[385,6][1,1][1,6][385,1]" + "[385,6]" + "[11,7,5,3,2]" + "[][0,1,2]"); + // [3, 2] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({3, 2}, {11, 7, 5, 3, 2}), + "[1,6][385,1][385,6][1,1]" + "[385,6]" + "[11,7,5,3,2]" + "[0,1,2][]"); +} + +TEST(BCastTest, Basic_Tensor_Matrix_Column) { + // [11, 7, 5, 3, 2] [3, 1] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 1}), + "[385,3,2][1,1,1][1,3,1][385,1,2]" + "[385,3,2]" + "[11,7,5,3,2]" + "[][0,1,2,4]"); + + // [3, 1] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({3, 1}, {11, 7, 5, 3, 2}), + "[1,3,1][385,1,2][385,3,2][1,1,1]" + "[385,3,2]" + "[11,7,5,3,2]" + "[0,1,2,4][]"); +} + +TEST(BCastTest, Basic_Tensor_Matrix_As_Tensor) { + // [11, 7, 5, 3, 2] [7, 5, 1, 1] + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {7, 5, 1, 1}), + "[11,35,6][1,1,1][1,35,1][11,1,6]" + "[11,35,6]" + "[11,7,5,3,2]" + "[][0,3,4]"); + + // [7, 5, 1, 1] [11, 7, 5, 3, 2] + EXPECT_EQ(BCast({7, 5, 1, 1}, {11, 7, 5, 3, 2}), + "[1,35,1][11,1,6][11,35,6][1,1,1]" + "[11,35,6]" + "[11,7,5,3,2]" + "[0,3,4][]"); +} + +TEST(BCastTest, Complex_BCast_To_Each_Other) { + // Rare cases. x and y broadcast to each other. x and y are of + // different ranks. 
+ // Can be verified in numpy as: + // import numpy as np + // x = np.arange(0,110).reshape([11,1,5,1,2]) + // y = np.arange(0,21).reshape([7,1,3,1]) + // np.shape(x + y) + // Out[.]: (11, 7, 5, 3, 2) + EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}), + "[11,1,5,1,2][1,7,1,3,1][1,7,1,3,1][11,1,5,1,2]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[1,3][0,2,4]"); +} + +TEST(BCastTest, TestZeroDimensionShape) { + EXPECT_EQ(BCast({2, 0, 5}, {5}), + "[0,5][1,1][1,5][0,1]" + "[0,5]" + "[2,0,5]" + "[][0,1]"); + EXPECT_EQ(BCast({5}, {2, 0, 5}), + "[1,5][0,1][0,5][1,1]" + "[0,5]" + "[2,0,5]" + "[0,1][]"); + + EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {5}), + "[0,5][1,1][1,5][0,1]" + "[0,5]" + "[2,0,3,0,5]" + "[][0,1,2,3]"); + EXPECT_EQ(BCast({5}, {2, 0, 3, 0, 5}), + "[1,5][0,1][0,5][1,1]" + "[0,5]" + "[2,0,3,0,5]" + "[0,1,2,3][]"); + + EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {3, 1, 5}), + "[0,3,0,5][1,1,1,1][1,3,1,5][0,1,0,1]" + "[0,3,0,5]" + "[2,0,3,0,5]" + "[][0,1,3]"); + EXPECT_EQ(BCast({3, 1, 5}, {2, 0, 3, 0, 5}), + "[1,3,1,5][0,1,0,1][0,3,0,5][1,1,1,1]" + "[0,3,0,5]" + "[2,0,3,0,5]" + "[0,1,3][]"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc new file mode 100644 index 0000000000..b8c6a77dd0 --- /dev/null +++ b/tensorflow/core/util/device_name_utils.cc @@ -0,0 +1,338 @@ +#include "tensorflow/core/util/device_name_utils.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +static bool IsAlpha(char c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); +} + +static bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); } + +// Returns true iff "in" is a valid job name. +static bool IsJobName(StringPiece in) { + if (in.empty()) return false; + if (!IsAlpha(in[0])) return false; + for (size_t i = 1; i < in.size(); ++i) { + if (!(IsAlphaNum(in[i]) || in[i] == '_')) return false; + } + return true; +} + +// Returns true and fills in "*job" iff "*in" starts with a job name. +static bool ConsumeJobName(StringPiece* in, string* job) { + if (in->empty()) return false; + if (!IsAlpha((*in)[0])) return false; + size_t i = 1; + for (; i < in->size(); ++i) { + const char c = (*in)[i]; + if (c == '/') break; + if (!(IsAlphaNum(c) || c == '_')) { + return false; + } + } + job->assign(in->data(), i); + in->remove_prefix(i); + return true; +} + +// Returns true and fills in "*device_type" iff "*in" starts with a device type +// name. +static bool ConsumeDeviceType(StringPiece* in, string* device_type) { + if (in->empty()) return false; + if (!IsAlpha((*in)[0])) return false; + size_t i = 1; + for (; i < in->size(); ++i) { + const char c = (*in)[i]; + if (c == '/' || c == ':') break; + if (!(IsAlphaNum(c) || c == '_')) { + return false; + } + } + device_type->assign(in->data(), i); + in->remove_prefix(i); + return true; +} + +// Returns true and fills in "*val" iff "*in" starts with a decimal +// number. 
+static bool ConsumeNumber(StringPiece* in, int* val) { + uint64 tmp; + if (str_util::ConsumeLeadingDigits(in, &tmp)) { + *val = tmp; + return true; + } else { + return false; + } +} + +/* static */ +string DeviceNameUtils::FullName(const string& job, int replica, int task, + const string& type, int id) { + CHECK(IsJobName(job)) << job; + CHECK_LE(0, replica); + CHECK_LE(0, task); + CHECK(!type.empty()); + CHECK_LE(0, id); + return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task, + "/device:", type, ":", id); +} + +bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) { + p->Clear(); + if (fullname == "/") { + return true; + } + StringPiece tmp; + while (!fullname.empty()) { + if (str_util::ConsumePrefix(&fullname, "/job:")) { + p->has_job = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_job && !ConsumeJobName(&fullname, &p->job)) { + return false; + } + } else if (str_util::ConsumePrefix(&fullname, "/replica:")) { + p->has_replica = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) { + return false; + } + } else if (str_util::ConsumePrefix(&fullname, "/task:")) { + p->has_task = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_task && !ConsumeNumber(&fullname, &p->task)) { + return false; + } + } else if (str_util::ConsumePrefix(&fullname, "/device:")) { + p->has_type = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) { + return false; + } + if (!str_util::ConsumePrefix(&fullname, ":")) { + p->has_id = false; + } else { + p->has_id = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_id && !ConsumeNumber(&fullname, &p->id)) { + return false; + } + } + + } else if (str_util::ConsumePrefix(&fullname, "/cpu:") || + str_util::ConsumePrefix(&fullname, "/CPU:")) { + p->has_type = true; + p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...' + p->has_id = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_id && !ConsumeNumber(&fullname, &p->id)) { + return false; + } + } else if (str_util::ConsumePrefix(&fullname, "/gpu:") || + str_util::ConsumePrefix(&fullname, "/GPU:")) { + p->has_type = true; + p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...' 
+ p->has_id = !str_util::ConsumePrefix(&fullname, "*"); + if (p->has_id && !ConsumeNumber(&fullname, &p->id)) { + return false; + } + } else { + return false; + } + } + return true; +} + +/* static */ +string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) { + string buf; + if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job); + if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica); + if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task); + if (pn.has_type) { + strings::StrAppend(&buf, "/", pn.type, ":"); + if (pn.has_id) { + strings::StrAppend(&buf, pn.id); + } else { + strings::StrAppend(&buf, "*"); + } + } + return buf; +} + +/* static */ +bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific, + const ParsedName& more_specific) { + if (less_specific.has_job && + (!more_specific.has_job || (less_specific.job != more_specific.job))) { + return false; + } + if (less_specific.has_replica && + (!more_specific.has_replica || + (less_specific.replica != more_specific.replica))) { + return false; + } + if (less_specific.has_task && + (!more_specific.has_task || (less_specific.task != more_specific.task))) { + return false; + } + if (less_specific.has_type && + (!more_specific.has_type || (less_specific.type != more_specific.type))) { + return false; + } + if (less_specific.has_id && + (!more_specific.has_id || (less_specific.id != more_specific.id))) { + return false; + } + return true; +} + +/* static */ +bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern, + const ParsedName& name) { + CHECK(name.has_job && name.has_replica && name.has_task && name.has_type && + name.has_id); + + if (pattern.has_job && (pattern.job != name.job)) return false; + if (pattern.has_replica && (pattern.replica != name.replica)) return false; + if (pattern.has_task && (pattern.task != name.task)) return false; + if (pattern.has_type && (pattern.type != name.type)) return false; + if (pattern.has_id && (pattern.id != name.id)) return false; + return true; +} + +/* static */ +Status DeviceNameUtils::MergeDevNames(ParsedName* target, + const ParsedName& other, + bool allow_soft_placement) { + if (other.has_job) { + if (target->has_job && target->job != other.job) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible jobs: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_job = other.has_job; + target->job = other.job; + } + } + + if (other.has_replica) { + if (target->has_replica && target->replica != other.replica) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible replicas: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_replica = other.has_replica; + target->replica = other.replica; + } + } + + if (other.has_task) { + if (target->has_task && target->task != other.task) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible tasks: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_task = other.has_task; + target->task = other.task; + } + } + + if (other.has_type) { + if (target->has_type && target->type != other.type) { + if (!allow_soft_placement) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible types: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_id = false; + target->has_type = false; + return Status::OK(); + } + } 
else { + target->has_type = other.has_type; + target->type = other.type; + } + } + + if (other.has_id) { + if (target->has_id && target->id != other.id) { + if (!allow_soft_placement) { + return errors::InvalidArgument( + "Cannot merge devices with incompatible ids: '", + ParsedNameToString(*target), "' and '", ParsedNameToString(other), + "'"); + } else { + target->has_id = false; + return Status::OK(); + } + } else { + target->has_id = other.has_id; + target->id = other.id; + } + } + + return Status::OK(); +} + +/* static */ +bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a, + const ParsedName& b) { + return (a.has_job && b.has_job && (a.job == b.job)) && + (a.has_replica && b.has_replica && (a.replica == b.replica)) && + (a.has_task && b.has_task && (a.task == b.task)); +} + +/* static */ +bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) { + ParsedName x; + ParsedName y; + return ParseFullName(src, &x) && ParseFullName(dst, &y) && + IsSameAddressSpace(x, y); +} + +/* static */ +string DeviceNameUtils::LocalName(StringPiece type, int id) { + return strings::StrCat(type, ":", id); +} + +/* static */ +string DeviceNameUtils::LocalName(StringPiece fullname) { + ParsedName x; + CHECK(ParseFullName(fullname, &x)) << fullname; + return LocalName(x.type, x.id); +} + +/* static */ +bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) { + ParsedName x; + if (!ConsumeDeviceType(&name, &p->type)) { + return false; + } + if (!str_util::ConsumePrefix(&name, ":")) { + return false; + } + if (!ConsumeNumber(&name, &p->id)) { + return false; + } + return name.empty(); +} + +/* static */ +bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task, + string* device) { + ParsedName pn; + if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) { + *task = strings::StrCat( + (pn.has_job ? strings::StrCat("/job:", pn.job) : ""), + (pn.has_replica ? strings::StrCat("/replica:", pn.replica) : ""), + (pn.has_task ? strings::StrCat("/task:", pn.task) : "")); + *device = strings::StrCat(pn.type, ":", pn.id); + return true; + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h new file mode 100644 index 0000000000..8b0a24ed0d --- /dev/null +++ b/tensorflow/core/util/device_name_utils.h @@ -0,0 +1,141 @@ +#ifndef TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_ +#define TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_ + +#include <string> + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// In TensorFlow a device name is a string of the following form: +// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> +// +// <name> is a short identifier conforming to the regexp +// [a-zA-Z][_a-zA-Z]* +// <type> is a supported device type (e.g. 'cpu' or 'gpu') +// <replica>, <task>, <device_num> are small non-negative integers and are +// densely allocated (except in tests). +// +// For some purposes, we also allow device patterns, which can specify +// some or none of the specific fields above, with missing components, +// or "<component>:*" indicating "any value allowed for that component. 
+// +// For example: +// "/job:param_server" - Consider any devices in the "param_server" job +// "/device:cpu:*" - Consider any cpu devices in any job/task/replica +// "/job:*/replica:*/task:*/device:cpu:*" - Consider any cpu devices in any +// job/task/replica +// "/job:w/replica:0/task:0/device:gpu:*" - Consider any gpu devices in +// replica 0, task 0, of job "w" +class DeviceNameUtils { + public: + // Returns a fully qualified device name given the parameters. + static string FullName(const string& job, int replica, int task, + const string& type, int id); + + struct ParsedName { + void Clear() { + has_job = false; + has_replica = false; + has_task = false; + has_type = false; + has_id = false; + job.clear(); + replica = 0; + task = 0; + type.clear(); + id = 0; + } + + bool operator==(const ParsedName& other) const { + return (has_job ? (other.has_job && job == other.job) : !other.has_job) && + (has_replica ? (other.has_replica && replica == other.replica) + : !other.has_replica) && + (has_task ? (other.has_task && task == other.task) + : !other.has_task) && + (has_type ? (other.has_type && type == other.type) + : !other.has_type) && + (has_id ? (other.has_id && id == other.id) : !other.has_id); + } + + bool has_job = false; + string job; + bool has_replica = false; + int replica = 0; + bool has_task = false; + int task = 0; + bool has_type = false; + string type; + bool has_id = false; + int id = 0; + }; + // Parses "fullname" into "*parsed". Returns true iff succeeds. + static bool ParseFullName(StringPiece fullname, ParsedName* parsed); + + // Returns true if "name" specifies any non-trivial constraint on the device. + static bool HasSomeDetails(const ParsedName& name) { + return name.has_job || name.has_replica || name.has_task || name.has_type || + name.has_id; + } + + // Returns true if more_specific is a specification of + // less_specific, i.e. everywhere that less-specific has a + // non-wildcard component value, more_specific has the same value + // for that component. + static bool IsSpecification(const ParsedName& less_specific, + const ParsedName& more_specific); + + // Like IsSpecification, but the second argument "name" must have a + // non-wildcard value for all of its components. + static bool IsCompleteSpecification(const ParsedName& pattern, + const ParsedName& name); + + // True iff there exists any possible complete device name that is + // a specification of both "a" and "b". + static inline bool AreCompatibleDevNames(const ParsedName& a, + const ParsedName& b) { + return IsSpecification(a, b) || IsSpecification(b, a); + } + + // Merges the device specifications in "*target" and "other", and + // stores the result in "*target". Returns OK if "*target" and + // "other" are compatible, otherwise returns an error. + static Status MergeDevNames(ParsedName* target, const ParsedName& other) { + return MergeDevNames(target, other, false); + } + static Status MergeDevNames(ParsedName* target, const ParsedName& other, + bool allow_soft_placement); + + // Returns true iff devices identified by 'src' and 'dst' are in the + // same address space. + static bool IsSameAddressSpace(StringPiece src, StringPiece dst); + static bool IsSameAddressSpace(const ParsedName& src, const ParsedName& dst); + + // Returns the local device given its "type" and "id". + static string LocalName(StringPiece type, int id); + + // Returns a short local device name (cpu:0, gpu:1, etc) based on + // the given fullname. 
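+  // For example, LocalName("/job:foo/replica:1/task:2/device:CPU:3")
+  // returns "CPU:3" (see device_name_utils_test.cc below).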
+ static string LocalName(StringPiece fullname); + + // If "name" is a valid local device name (cpu:0, gpu:1, etc.), + // fills in parsed.type and parsed.id accordingly. Returns true iff + // succeeds. + static bool ParseLocalName(StringPiece name, ParsedName* parsed); + + // Splits a fully-qualified device name into a task identifier and a + // relative device identifier. It first parses "name" using + // ParseFullName(), then assigns *task with everything except for + // the local device component, and assigns the relative device + // component into *device. This function will still return true if + // the task component is empty, but it requires the relative device + // component to be fully specified. + static bool SplitDeviceName(StringPiece name, string* task, string* device); + + static string ParsedNameToString(const ParsedName& pn); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_ diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc new file mode 100644 index 0000000000..14f30d6de5 --- /dev/null +++ b/tensorflow/core/util/device_name_utils_test.cc @@ -0,0 +1,369 @@ +#include "tensorflow/core/util/device_name_utils.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +TEST(DeviceNameUtilsTest, Basic) { + EXPECT_EQ(DeviceNameUtils::FullName("hello", 1, 2, "CPU", 3), + "/job:hello/replica:1/task:2/device:CPU:3"); + + { + DeviceNameUtils::ParsedName p; + EXPECT_FALSE(DeviceNameUtils::ParseFullName("foobar", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:123/replica:1/task:2/gpu:3", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:123/replica:1/task:2/gpu:", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseFullName( + "/job:123/replica:1/task:2/device:gpu:", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:foo/replica:-1/task:2/gpu:3", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:-2/gpu:3", &p)); + EXPECT_FALSE( + DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:2/bar:3", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseFullName( + "/job:foo/replica:1/task:2/gpu:3/extra", &p)); + EXPECT_TRUE( + DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:2/gpu:3", &p)); + EXPECT_TRUE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_TRUE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.job, "foo"); + EXPECT_EQ(p.replica, 1); + EXPECT_EQ(p.task, 2); + EXPECT_EQ(p.type, "GPU"); + EXPECT_EQ(p.id, 3); + } + { + // Allow _ in job names. + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName( + "/job:foo_bar/replica:1/task:2/gpu:3", &p)); + EXPECT_TRUE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_TRUE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.job, "foo_bar"); + EXPECT_EQ(p.replica, 1); + EXPECT_EQ(p.task, 2); + EXPECT_EQ(p.type, "GPU"); + EXPECT_EQ(p.id, 3); + } + { + // Allow _ in job names. 
+ DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName( + "/job:foo_bar/replica:1/task:2/device:GPU:3", &p)); + EXPECT_TRUE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_TRUE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.job, "foo_bar"); + EXPECT_EQ(p.replica, 1); + EXPECT_EQ(p.task, 2); + EXPECT_EQ(p.type, "GPU"); + EXPECT_EQ(p.id, 3); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName("/job:*/replica:4/gpu:*", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_FALSE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE( + DeviceNameUtils::ParseFullName("/job:*/replica:4/device:GPU:*", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_FALSE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE( + DeviceNameUtils::ParseFullName("/job:*/device:GPU/replica:4", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_FALSE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName( + "/job:*/replica:4/device:myspecialdevice:13", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "myspecialdevice"); + EXPECT_EQ(p.id, 13); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName("/", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_FALSE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_FALSE(p.has_type); + EXPECT_FALSE(p.has_id); + } + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName("/job:*/replica:4/gpu:5", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_TRUE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + EXPECT_EQ(p.id, 5); + } + { // Same result if we reorder the components + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseFullName("/gpu:*/job:*/replica:4", &p)); + EXPECT_FALSE(p.has_job); + EXPECT_TRUE(p.has_replica); + EXPECT_FALSE(p.has_task); + EXPECT_TRUE(p.has_type); + EXPECT_FALSE(p.has_id); + EXPECT_EQ(p.replica, 4); + EXPECT_EQ(p.type, "GPU"); + } + + EXPECT_TRUE(DeviceNameUtils::IsSameAddressSpace( + "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:1/task:2/gpu:4")); + EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace( + "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:1/task:3/gpu:4")); + EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace( + "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:10/task:2/gpu:4")); + EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace( + "/job:foo/replica:1/task:2/cpu:3", "/job:bar/replica:1/task:2/gpu:4")); + + EXPECT_EQ(DeviceNameUtils::LocalName("CPU", 1), "CPU:1"); + EXPECT_EQ(DeviceNameUtils::LocalName("GPU", 2), "GPU:2"); + EXPECT_EQ(DeviceNameUtils::LocalName("MySpecialDevice", 13), + "MySpecialDevice:13"); + + EXPECT_EQ( + DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/device:CPU:3"), + "CPU:3"); + + EXPECT_EQ(DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/cpu:3"), + "CPU:3"); + + 
EXPECT_EQ( + DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/device:abc:73"), + "abc:73"); + + { + DeviceNameUtils::ParsedName p; + EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p)); + EXPECT_EQ(p.type, "CPU"); + EXPECT_EQ(p.id, 10); + EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p)); + } +} + +static bool IsCSHelper(StringPiece pattern, StringPiece actual) { + DeviceNameUtils::ParsedName p, a; + EXPECT_TRUE(DeviceNameUtils::ParseFullName(pattern, &p)); + EXPECT_TRUE(DeviceNameUtils::ParseFullName(actual, &a)); + return DeviceNameUtils::IsCompleteSpecification(p, a); +} + +TEST(DeviceNameUtilsTest, IsCompleteSpecification) { + EXPECT_TRUE(IsCSHelper("/job:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE( + IsCSHelper("/job:*/replica:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsCSHelper("/job:*/task:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsCSHelper("/job:*/replica:*/task:*", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE( + IsCSHelper("/job:*/replica:*/gpu:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_FALSE(IsCSHelper("/cpu:*", "/job:worker/replica:1/task:2/gpu:3")); + EXPECT_FALSE(IsCSHelper("/gpu:2", "/job:worker/replica:1/task:2/gpu:1")); + EXPECT_TRUE(IsCSHelper("/gpu:*", "/job:worker/replica:1/task:2/gpu:3")); +} + +static bool IsSpecHelper(StringPiece pattern, StringPiece actual) { + DeviceNameUtils::ParsedName p, a; + EXPECT_TRUE(DeviceNameUtils::ParseFullName(pattern, &p)); + EXPECT_TRUE(DeviceNameUtils::ParseFullName(actual, &a)); + return DeviceNameUtils::IsSpecification(p, a); +} + +TEST(DeviceNameUtilsTest, IsSpecification) { + EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1")); + EXPECT_TRUE(IsSpecHelper("/job:*", "/replica:1")); + EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work")); + EXPECT_TRUE( + IsSpecHelper("/job:*/replica:*", "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/gpu:*", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/gpu:3", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/task:2", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/job:work/replica:*/task:2", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/task:*", "/job:*/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/task:2", "/job:*/replica:1/task:2/gpu:3")); + EXPECT_TRUE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2/cpu:1")); + EXPECT_TRUE(IsSpecHelper("/cpu:0", "/cpu:0")); + EXPECT_TRUE(IsSpecHelper("/gpu:*", "/job:worker/replica:1/task:2/gpu:3")); + + EXPECT_FALSE(IsSpecHelper("/job:worker/replica:1/task:2/gpu:3", "/gpu:*")); + EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2")); + EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2/gpu:1")); + EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:worker/replica:1/task:2/gpu:3")); + EXPECT_FALSE(IsSpecHelper("/gpu:2", "/job:worker/replica:1/task:2/gpu:1")); + EXPECT_FALSE(IsSpecHelper("/job:work/replica:*/task:0", + "/job:work/replica:1/task:2/gpu:3")); + EXPECT_FALSE(IsSpecHelper("/job:work/replica:0/task:2", + "/job:work/replica:*/task:2/gpu:3")); +} + 
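For illustration, the pattern-matching behavior exercised by these tests can be wrapped in a small caller-side helper built only on the public API above; the function name below is hypothetical and not part of this change:

#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

// Hypothetical helper: returns true if "device" (a full or partial device
// name) satisfies the constraint "pattern", which may use '*' wildcards,
// e.g. "/job:*/replica:4/device:GPU:*".
bool DeviceSatisfiesConstraint(StringPiece pattern, StringPiece device) {
  DeviceNameUtils::ParsedName p, d;
  if (!DeviceNameUtils::ParseFullName(pattern, &p)) return false;
  if (!DeviceNameUtils::ParseFullName(device, &d)) return false;
  // IsSpecification(less_specific, more_specific): every non-wildcard
  // component of "p" must match the corresponding component of "d".
  return DeviceNameUtils::IsSpecification(p, d);
}

}  // namespace tensorflow

When the second argument is known to be fully specified, IsCompleteSpecification(p, d) can be used instead; per the implementation above it CHECK-fails unless every component of the name is present.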
+TEST(DeviceNameUtilsTest, SplitDeviceName) { + string task; + string device; + EXPECT_TRUE(DeviceNameUtils::SplitDeviceName( + "/job:foo/replica:1/task:2/cpu:1", &task, &device)); + EXPECT_EQ("/job:foo/replica:1/task:2", task); + EXPECT_EQ("CPU:1", device); + EXPECT_TRUE(DeviceNameUtils::SplitDeviceName( + "/job:foo/cpu:1/task:2/replica:1", &task, &device)); + EXPECT_EQ("/job:foo/replica:1/task:2", task); + EXPECT_EQ("CPU:1", device); + EXPECT_TRUE(DeviceNameUtils::SplitDeviceName("/gpu:3", &task, &device)); + EXPECT_EQ("", task); + EXPECT_EQ("GPU:3", device); + EXPECT_FALSE(DeviceNameUtils::SplitDeviceName("gpu:3", &task, &device)); + EXPECT_FALSE(DeviceNameUtils::SplitDeviceName("/job:foo/task:2/replica:1", + &task, &device)); + EXPECT_TRUE(DeviceNameUtils::SplitDeviceName("/device:myspecialdevice:3", + &task, &device)); + EXPECT_EQ("", task); + EXPECT_EQ("myspecialdevice:3", device); +} + +static DeviceNameUtils::ParsedName Name(const string& str) { + DeviceNameUtils::ParsedName ret; + CHECK(DeviceNameUtils::ParseFullName(str, &ret)) << "Invalid name: " << str; + return ret; +} + +static void MergeDevNamesHelperImpl(const string& name_a, const string& name_b, + const string& expected_merge_name, + bool allow_soft_placement) { + DeviceNameUtils::ParsedName target_a = Name(name_a); + EXPECT_OK(DeviceNameUtils::MergeDevNames(&target_a, Name(name_b), + allow_soft_placement)); + DeviceNameUtils::ParsedName target_b = Name(name_b); + EXPECT_OK(DeviceNameUtils::MergeDevNames(&target_b, Name(name_a), + allow_soft_placement)); + EXPECT_EQ(target_a, target_b); + EXPECT_EQ(target_a, Name(expected_merge_name)); + EXPECT_EQ(target_b, Name(expected_merge_name)); +} + +static void MergeDevNamesHelper(const string& name_a, const string& name_b, + const string& expected_merge_name) { + MergeDevNamesHelperImpl(name_a, name_b, expected_merge_name, false); +} + +static void MergeDevNamesHelperAllowSoftPlacement( + const string& name_a, const string& name_b, + const string& expected_merge_name) { + MergeDevNamesHelperImpl(name_a, name_b, expected_merge_name, true); +} + +static void MergeDevNamesError(const string& name_a, const string& name_b, + const string& expected_error_substr) { + DeviceNameUtils::ParsedName target_a = Name(name_a); + Status s = DeviceNameUtils::MergeDevNames(&target_a, Name(name_b)); + EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); + EXPECT_TRUE(StringPiece(s.error_message()).contains(expected_error_substr)) + << s; +} + +TEST(DeviceNameUtilsTest, MergeDevNames) { + DeviceNameUtils::ParsedName target; + + // Idempotence tests. + MergeDevNamesHelper("", "", ""); + MergeDevNamesHelper("/job:foo/replica:1/task:2/cpu:1", + "/job:foo/replica:1/task:2/cpu:1", + "/job:foo/replica:1/task:2/cpu:1"); + + // Merging with empty device has no effect. + MergeDevNamesHelper("", "/job:foo", "/job:foo"); + MergeDevNamesHelper("", "/replica:2", "/replica:2"); + MergeDevNamesHelper("", "/task:7", "/task:7"); + // MergeDevNamesHelper("", "/gpu:1", "/gpu:1"); + + // Combining disjoint names. + MergeDevNamesHelper("/job:foo", "/task:7", "/job:foo/task:7"); + MergeDevNamesHelper("/job:foo", "/gpu:1", "/job:foo/gpu:1"); + + // Combining overlapping names. + MergeDevNamesHelper("/job:foo/replica:0", "/replica:0/task:1", + "/job:foo/replica:0/task:1"); + + // Wildcard tests. + MergeDevNamesHelper("", "/gpu:*", "/gpu:*"); + MergeDevNamesHelper("/gpu:*", "/gpu:*", "/gpu:*"); + MergeDevNamesHelper("/gpu:1", "/gpu:*", "/gpu:1"); + + // Incompatible components. 
+ MergeDevNamesError("/job:foo", "/job:bar", "incompatible jobs"); + MergeDevNamesError("/replica:0", "/replica:1", "incompatible replicas"); + MergeDevNamesError("/task:0", "/task:1", "incompatible tasks"); + MergeDevNamesError("/gpu:*", "/cpu:*", "incompatible types"); + MergeDevNamesError("/gpu:0", "/gpu:1", "incompatible ids"); +} + +TEST(DeviceNameUtilsTest, MergeDevNamesAllowSoftPlacement) { + // Incompatible components with allow_soft_placement. + MergeDevNamesHelperAllowSoftPlacement("/gpu:*", "/cpu:1", ""); + MergeDevNamesHelperAllowSoftPlacement("/cpu:*", "/gpu:1", ""); + MergeDevNamesHelperAllowSoftPlacement("/gpu:1", "/gpu:2", "/gpu:*"); +} + +static void BM_ParseFullName(int iters) { + DeviceNameUtils::ParsedName p; + while (iters--) { + DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p); + } +} +BENCHMARK(BM_ParseFullName); + +} // namespace tensorflow diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto new file mode 100644 index 0000000000..5d67823ce7 --- /dev/null +++ b/tensorflow/core/util/event.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/summary.proto"; + +// Protocol buffer representing an event that happened during +// the execution of a Brain model. +message Event { + // Timestamp of the event. + double wall_time = 1; + + // Globale step of the event. + int64 step = 2; + + oneof what { + // An event file was started, with the specified version. + // This is use to identify the contents of the record IO files + // easily. Current version is "tensorflow.Event:1". All versions + // start with "tensorflow.Event:". + string file_version = 3; + // A model was constructed. + GraphDef graph_def = 4; + // A summary was generated. + Summary summary = 5; + } +} diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc new file mode 100644 index 0000000000..1b34a36577 --- /dev/null +++ b/tensorflow/core/util/events_writer.cc @@ -0,0 +1,144 @@ +#include "tensorflow/core/util/events_writer.h" + +#include <stddef.h> // for NULL + +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { + +EventsWriter::EventsWriter(const string& file_prefix) + // TODO(jeff,sanjay): Pass in env and use that here instead of Env::Default + : env_(Env::Default()), + file_prefix_(file_prefix), + num_outstanding_events_(0) {} + +bool EventsWriter::Init() { + if (recordio_writer_.get() != nullptr) { + CHECK(!filename_.empty()); + if (FileHasDisappeared()) { + // Warn user of data loss and let .reset() below do basic cleanup. + if (num_outstanding_events_ > 0) { + LOG(WARNING) << "Re-intialization, attempting to open a new file, " + << num_outstanding_events_ << " events will be lost."; + } + } else { + // No-op: File is present and writer is initialized. 
+ return true; + } + } + + int64 time_in_seconds = env_->NowMicros() / 1000000; + + filename_ = strings::Printf( + "%s.out.tfevents.%010lld.%s", file_prefix_.c_str(), + static_cast<long long>(time_in_seconds), port::Hostname().c_str()); + port::AdjustFilenameForLogging(&filename_); + + WritableFile* file; + Status s = env_->NewWritableFile(filename_, &file); + if (!s.ok()) { + LOG(ERROR) << "Could not open events file: " << filename_ << ": " << s; + return false; + } + recordio_file_.reset(file); + recordio_writer_.reset(new io::RecordWriter(recordio_file_.get())); + if (recordio_writer_.get() == NULL) { + LOG(ERROR) << "Could not create record writer"; + return false; + } + num_outstanding_events_ = 0; + VLOG(1) << "Successfully opened events file: " << filename_; + { + // Write the first event with the current version, and flush + // right away so the file contents will be easily determined. + + Event event; + event.set_wall_time(time_in_seconds); + event.set_file_version(strings::StrCat(kVersionPrefix, kCurrentVersion)); + WriteEvent(event); + Flush(); + } + return true; +} + +string EventsWriter::FileName() { + if (filename_.empty()) { + Init(); + } + return filename_; +} + +void EventsWriter::WriteSerializedEvent(const string& event_str) { + if (recordio_writer_.get() == NULL) { + if (!Init()) { + LOG(ERROR) << "Write failed because file could not be opened."; + return; + } + } + num_outstanding_events_++; + recordio_writer_->WriteRecord(event_str); +} + +void EventsWriter::WriteEvent(const Event& event) { + string record; + event.AppendToString(&record); + WriteSerializedEvent(record); +} + +bool EventsWriter::Flush() { + if (num_outstanding_events_ == 0) return true; + CHECK(recordio_file_.get() != NULL) << "Unexpected NULL file"; + // The FileHasDisappeared() condition is necessary because + // recordio_writer_->Sync() can return true even if the underlying + // file has been deleted. EventWriter.FileDeletionBeforeWriting + // demonstrates this and will fail if the FileHasDisappeared() + // conditon is removed. + // Also, we deliberately attempt to Sync() before checking for a + // disappearing file, in case for some file system File::Exists() is + // false after File::Open() but before File::Sync(). + if (!recordio_file_->Flush().ok() || !recordio_file_->Sync().ok() || + FileHasDisappeared()) { + LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to " + << filename_; + return false; + } + VLOG(1) << "Wrote " << num_outstanding_events_ << " events to disk."; + num_outstanding_events_ = 0; + return true; +} + +bool EventsWriter::Close() { + bool return_value = Flush(); + if (recordio_file_.get() != NULL) { + Status s = recordio_file_->Close(); + if (!s.ok()) { + LOG(ERROR) << "Error when closing previous event file: " << filename_ + << ": " << s; + return_value = false; + } + recordio_writer_.reset(NULL); + recordio_file_.reset(NULL); + } + num_outstanding_events_ = 0; + return return_value; +} + +bool EventsWriter::FileHasDisappeared() { + if (env_->FileExists(filename_)) { + return false; + } else { + // This can happen even with non-null recordio_writer_ if some other + // process has removed the file. 
+ LOG(ERROR) << "The events file " << filename_ << " has disappeared."; + return true; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h new file mode 100644 index 0000000000..e6b94ad265 --- /dev/null +++ b/tensorflow/core/util/events_writer.h @@ -0,0 +1,77 @@ +#ifndef TENSORFLOW_UTIL_EVENTS_WRITER_H_ +#define TENSORFLOW_UTIL_EVENTS_WRITER_H_ + +#include <memory> +#include <string> +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { + +class EventsWriter { + public: +#ifndef SWIG + // Prefix of version string present in the first entry of every event file. + static constexpr const char* kVersionPrefix = "brain.Event:"; + static constexpr const int kCurrentVersion = 1; +#endif + + // Events files typically have a name of the form + // '/some/file/path/my.file.out.events.[timestamp].[hostname]' + // To create and EventWriter, the user should provide file_prefix = + // '/some/file/path/my.file' + // The EventsWriter will append '.out.events.[timestamp].[hostname]' + // to the ultimate filename once Init() is called. + // Note that it is not recommended to simultaneously have two + // EventWriters writing to the same file_prefix. + explicit EventsWriter(const string& file_prefix); + ~EventsWriter() { Close(); } // Autoclose in destructor. + + // Sets the event file filename and opens file for writing. If not called by + // user, will be invoked automatically by a call to FileName() or Write*(). + // Returns false if the file could not be opened. Idempotent: if file exists + // and is open this is a no-op. If on the other hand the file was opened, + // but has since disappeared (e.g. deleted by another process), this will open + // a new file with a new timestamp in its filename. + bool Init(); + + // Returns the filename for the current events file: + // filename_ = [file_prefix_].out.events.[timestamp].[hostname] + string FileName(); + + // Append "event" to the file. The "tensorflow::" part is for swig happiness. + void WriteEvent(const tensorflow::Event& event); + + // Append "event_str", a serialized Event, to the file. + // Note that this function does NOT check that de-serializing event_str + // results in a valid Event proto. + void WriteSerializedEvent(const string& event_str); + + // EventWriter automatically flushes and closes on destruction, but + // these two methods are provided for users who want to write to disk sooner + // and/or check for success. + // Flush() pushes outstanding events to disk. Returns false if the + // events file could not be created, or if the file exists but could not + // be written too. + // Close() calls Flush() and then closes the current events file. + // Returns true only if both the flush and the closure were successful. + bool Flush(); + bool Close(); + + private: + bool FileHasDisappeared(); // True if event_file_path_ does not exist. 
+ + Env* env_; + const string file_prefix_; + string filename_; + std::unique_ptr<WritableFile> recordio_file_; + std::unique_ptr<io::RecordWriter> recordio_writer_; + int num_outstanding_events_; + TF_DISALLOW_COPY_AND_ASSIGN(EventsWriter); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_EVENTS_WRITER_H_ diff --git a/tensorflow/core/util/events_writer_test.cc b/tensorflow/core/util/events_writer_test.cc new file mode 100644 index 0000000000..f6523ead92 --- /dev/null +++ b/tensorflow/core/util/events_writer_test.cc @@ -0,0 +1,198 @@ +#include "tensorflow/core/util/events_writer.h" + +#include <math.h> +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { +namespace { + +// shorthand +Env* env() { return Env::Default(); } + +void WriteSimpleValue(EventsWriter* writer, double wall_time, int64 step, + const string& tag, float simple_value) { + Event event; + event.set_wall_time(wall_time); + event.set_step(step); + Summary::Value* summ_val = event.mutable_summary()->add_value(); + summ_val->set_tag(tag); + summ_val->set_simple_value(simple_value); + writer->WriteEvent(event); +} + +void WriteFile(EventsWriter* writer) { + WriteSimpleValue(writer, 1234, 34, "foo", 3.14159); + WriteSimpleValue(writer, 2345, 35, "bar", -42); +} + +static bool ReadEventProto(io::RecordReader* reader, uint64* offset, + Event* proto) { + string record; + Status s = reader->ReadRecord(offset, &record); + if (!s.ok()) { + return false; + } + return ParseProtoUnlimited(proto, record); +} + +void VerifyFile(const string& filename) { + CHECK(env()->FileExists(filename)); + RandomAccessFile* event_file; + TF_CHECK_OK(env()->NewRandomAccessFile(filename, &event_file)); + io::RecordReader* reader = new io::RecordReader(event_file); + + uint64 offset = 0; + + Event actual; + CHECK(ReadEventProto(reader, &offset, &actual)); + VLOG(1) << actual.ShortDebugString(); + // Wall time should be within 5s of now. + + double current_time = env()->NowMicros() / 1000000.0; + EXPECT_LT(fabs(actual.wall_time() - current_time), 5); + // Should have the current version number. 
+ EXPECT_EQ(actual.file_version(), + strings::StrCat(EventsWriter::kVersionPrefix, + EventsWriter::kCurrentVersion)); + + Event expected; + CHECK(ReadEventProto(reader, &offset, &actual)); + VLOG(1) << actual.ShortDebugString(); + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "wall_time: 1234 step: 34 " + "summary { value { tag: 'foo' simple_value: 3.14159 } }", + &expected)); + // TODO(keveman): Enable this check + // EXPECT_THAT(expected, EqualsProto(actual)); + + CHECK(ReadEventProto(reader, &offset, &actual)); + VLOG(1) << actual.ShortDebugString(); + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "wall_time: 2345 step: 35 " + "summary { value { tag: 'bar' simple_value: -42 } }", + &expected)); + // TODO(keveman): Enable this check + // EXPECT_THAT(expected, EqualsProto(actual)); + + TF_CHECK_OK(env()->DeleteFile(filename)); + + delete reader; + delete event_file; +} + +string GetDirName(const string& suffix) { + return io::JoinPath(testing::TmpDir(), suffix); +} + +TEST(EventWriter, WriteFlush) { + string file_prefix = GetDirName("/writeflush_test"); + EventsWriter writer(file_prefix); + WriteFile(&writer); + EXPECT_TRUE(writer.Flush()); + string filename = writer.FileName(); + VerifyFile(filename); +} + +TEST(EventWriter, WriteClose) { + string file_prefix = GetDirName("/writeclose_test"); + EventsWriter writer(file_prefix); + WriteFile(&writer); + EXPECT_TRUE(writer.Close()); + string filename = writer.FileName(); + VerifyFile(filename); +} + +TEST(EventWriter, WriteDelete) { + string file_prefix = GetDirName("/writedelete_test"); + EventsWriter* writer = new EventsWriter(file_prefix); + WriteFile(writer); + string filename = writer->FileName(); + delete writer; + VerifyFile(filename); +} + +TEST(EventWriter, FailFlush) { + string file_prefix = GetDirName("/failflush_test"); + EventsWriter writer(file_prefix); + string filename = writer.FileName(); + WriteFile(&writer); + EXPECT_TRUE(env()->FileExists(filename)); + env()->DeleteFile(filename); + EXPECT_FALSE(env()->FileExists(filename)); + EXPECT_FALSE(writer.Flush()); + EXPECT_FALSE(env()->FileExists(filename)); +} + +TEST(EventWriter, FailClose) { + string file_prefix = GetDirName("/failclose_test"); + EventsWriter writer(file_prefix); + string filename = writer.FileName(); + WriteFile(&writer); + EXPECT_TRUE(env()->FileExists(filename)); + env()->DeleteFile(filename); + EXPECT_FALSE(env()->FileExists(filename)); + EXPECT_FALSE(writer.Close()); + EXPECT_FALSE(env()->FileExists(filename)); +} + +TEST(EventWriter, InitWriteClose) { + string file_prefix = GetDirName("/initwriteclose_test"); + EventsWriter writer(file_prefix); + EXPECT_TRUE(writer.Init()); + string filename0 = writer.FileName(); + EXPECT_TRUE(env()->FileExists(filename0)); + WriteFile(&writer); + EXPECT_TRUE(writer.Close()); + string filename1 = writer.FileName(); + EXPECT_EQ(filename0, filename1); + VerifyFile(filename1); +} + +TEST(EventWriter, NameWriteClose) { + string file_prefix = GetDirName("/namewriteclose_test"); + EventsWriter writer(file_prefix); + string filename = writer.FileName(); + EXPECT_TRUE(env()->FileExists(filename)); + WriteFile(&writer); + EXPECT_TRUE(writer.Close()); + VerifyFile(filename); +} + +TEST(EventWriter, NameClose) { + string file_prefix = GetDirName("/nameclose_test"); + EventsWriter writer(file_prefix); + string filename = writer.FileName(); + EXPECT_TRUE(writer.Close()); + EXPECT_TRUE(env()->FileExists(filename)); + env()->DeleteFile(filename); +} + +TEST(EventWriter, FileDeletionBeforeWriting) { + string file_prefix = 
GetDirName("/fdbw_test"); + EventsWriter writer(file_prefix); + string filename0 = writer.FileName(); + EXPECT_TRUE(env()->FileExists(filename0)); + env()->SleepForMicroseconds( + 2000000); // To make sure timestamp part of filename will differ. + env()->DeleteFile(filename0); + EXPECT_TRUE(writer.Init()); // Init should reopen file. + WriteFile(&writer); + EXPECT_TRUE(writer.Flush()); + string filename1 = writer.FileName(); + EXPECT_NE(filename0, filename1); + VerifyFile(filename1); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/guarded_philox_random.cc b/tensorflow/core/util/guarded_philox_random.cc new file mode 100644 index 0000000000..4cf58b8979 --- /dev/null +++ b/tensorflow/core/util/guarded_philox_random.cc @@ -0,0 +1,39 @@ +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +Status GuardedPhiloxRandom::Init(OpKernelConstruction* context) { + // Grab seed Attrs. + int64 seed, seed2; + auto status = context->GetAttr("seed", &seed); + if (!status.ok()) return status; + status = context->GetAttr("seed2", &seed2); + if (!status.ok()) return status; + + // Initialize with the given seeds + Init(seed, seed2); + return Status::OK(); +} + +void GuardedPhiloxRandom::Init(int64 seed, int64 seed2) { + CHECK(!initialized_); + if (seed == 0 && seed2 == 0) { + // If both seeds are unspecified, use completely random seeds. + seed = random::New64(); + seed2 = random::New64(); + } + mutex_lock lock(mu_); + generator_ = random::PhiloxRandom(seed, seed2); + initialized_ = true; +} + +random::PhiloxRandom GuardedPhiloxRandom::ReserveSamples128(int64 samples) { + CHECK(initialized_); + mutex_lock lock(mu_); + auto local = generator_; + generator_.Skip(samples); + return local; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h new file mode 100644 index 0000000000..6e9cb9f99c --- /dev/null +++ b/tensorflow/core/util/guarded_philox_random.h @@ -0,0 +1,56 @@ +#ifndef TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ +#define TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// A thread safe wrapper around a Philox generator. Example usage: +// +// GuardedRandomPhilox generator; +// generator.Init(context); +// +// // In thread safe code +// const int samples = ...; +// auto local_generator = generator.ReserveSamples128(samples); +// for (int i = 0; i < samples; i++) +// Array<uint32, 4> sample = local_generator(); +// // Use sample +// } +// +class GuardedPhiloxRandom { + public: + // Must call Init to finish initialization + GuardedPhiloxRandom() : initialized_(false) {} + + // Initialize the generator from attributes "seed" and "seed2". + // If both seeds are unspecified, use random seeds. + // Must be called exactly once. + Status Init(OpKernelConstruction* context); + + // Initialize with given seeds. + void Init(int64 seed, int64 seed2); + + // Reserve a certain number of 128-bit samples. + // This function is thread safe. The returned generator is valid for the + // given number of samples, and can be used without a lock. 
+ random::PhiloxRandom ReserveSamples128(int64 samples); + + // Reserve a certain number of 32-bit samples + random::PhiloxRandom ReserveSamples32(int64 samples) { + return ReserveSamples128((samples + 3) / 4); + } + + private: + mutex mu_; + random::PhiloxRandom generator_ GUARDED_BY(mu_); + bool initialized_; + + TF_DISALLOW_COPY_AND_ASSIGN(GuardedPhiloxRandom); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_ diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc new file mode 100644 index 0000000000..24273e5ca4 --- /dev/null +++ b/tensorflow/core/util/padding.cc @@ -0,0 +1,24 @@ +#include "tensorflow/core/util/padding.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +Status GetNodeAttr(const NodeDef& node_def, const string& attr_name, + Padding* value) { + string str_value; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value)); + if (str_value == "SAME") { + *value = SAME; + } else if (str_value == "VALID") { + *value = VALID; + } else { + return errors::NotFound(str_value, " is not an allowed padding type"); + } + return Status::OK(); +} + +string GetPaddingAttrString() { return "padding: {'SAME', 'VALID'}"; } + +} // end namespace tensorflow diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h new file mode 100644 index 0000000000..66cd96abdb --- /dev/null +++ b/tensorflow/core/util/padding.h @@ -0,0 +1,37 @@ +#ifndef TENSORFLOW_UTIL_PADDING_H_ +#define TENSORFLOW_UTIL_PADDING_H_ + +// This file contains helper routines to deal with padding in various ops and +// kernels. + +#include <string> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Padding: the padding we apply to the input tensor along the rows and columns +// dimensions. This is usually used to make sure that the spatial dimensions do +// not shrink when we progress with convolutions. Two types of padding are +// supported: +// VALID: No padding is carried out. +// SAME: The pad value is computed so that the output will have the same +// dimensions as the input. +// The padded area is zero-filled. +enum Padding { + VALID = 1, // No padding. + SAME = 2, // Input and output layers have the same size. +}; + +// Return the string containing the list of valid padding types, that can be +// used as an Attr() in REGISTER_OP. +string GetPaddingAttrString(); + +// Specialization to parse an attribute directly into a Padding enum. +Status GetNodeAttr(const NodeDef& node_def, const string& attr_name, + Padding* value); + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_PADDING_H_ diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc new file mode 100644 index 0000000000..12eb076a4d --- /dev/null +++ b/tensorflow/core/util/port.cc @@ -0,0 +1,13 @@ +#include "tensorflow/core/util/port.h" + +namespace tensorflow { + +bool IsGoogleCudaEnabled() { +#if GOOGLE_CUDA + return true; +#else + return false; +#endif +} + +} // end namespace tensorflow diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h new file mode 100644 index 0000000000..8b9d033d63 --- /dev/null +++ b/tensorflow/core/util/port.h @@ -0,0 +1,11 @@ +#ifndef TENSORFLOW_UTIL_PORT_H_ +#define TENSORFLOW_UTIL_PORT_H_ + +namespace tensorflow { + +// Returns true if GOOGLE_CUDA is defined. 
+bool IsGoogleCudaEnabled(); + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_PORT_H_ diff --git a/tensorflow/core/util/saved_tensor_slice.proto b/tensorflow/core/util/saved_tensor_slice.proto new file mode 100644 index 0000000000..f6599d9669 --- /dev/null +++ b/tensorflow/core/util/saved_tensor_slice.proto @@ -0,0 +1,76 @@ +// Protocol buffers for saved tensor slices. It's used for the brain tensor +// ops checkpoints and the V3 checkpoints in dist_belief. + +// A checkpoint file is an sstable. The value for each record is a serialized +// SavedTensorSlices message (defined below). +// +// Each checkpoint file has a record with the empty key (""), which corresponds +// to a SavedTensorSlices message that contains a "meta", that serves as a +// table of contents on all the tensor slices saved in this file. Since the key +// is "", it's always the first record in each file. +// +// Each of the rest of the records in a checkpoint stores the raw data of a +// particular tensor slice, in SavedSlice format. The corresponding key is an +// ordered code that encodes the name of the tensor and the slice +// information. The name is also stored in the SaveSlice message for ease of +// debugging and manual examination. + +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/tensor_slice.proto"; +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/types.proto"; + +// Metadata describing the set of slices of the same tensor saved in a +// checkpoint file. +message SavedSliceMeta { + // Name of the tensor. + string name = 1; + + // Shape of the tensor + TensorShapeProto shape = 2; + + // Type of the tensor + DataType type = 3; + + // Explicit list of slices saved in the checkpoint file. + repeated TensorSliceProto slice = 4; +}; + +// Metadata describing the set of tensor slices saved in a checkpoint file. +// It is always stored at the beginning of each checkpoint file. +message SavedTensorSliceMeta { + // Each SavedSliceMeta describes the slices for one tensor. + repeated SavedSliceMeta tensor = 1; +}; + +// Saved tensor slice: it stores the name of the tensors, the slice, and the +// raw data. +message SavedSlice { + // Name of the tensor that this slice belongs to. This must be identical to + // the name used to encode the key for this record. + string name = 1; + + // Extent of the slice. Must have one entry for each of the dimension of the + // tensor that this slice belongs to. + TensorSliceProto slice = 2; + + // The raw data of the slice is stored as a TensorProto. Only raw data are + // stored (we don't fill in fields such as dtype or tensor_shape). + TensorProto data = 3; +}; + +// Each record in a v3 checkpoint file is a serialized SavedTensorSlices +// message. +message SavedTensorSlices { + // This is only present at the first item of each checkpoint file and serves + // as a table of contents, listing all the tensor slices saved in this file. + SavedTensorSliceMeta meta = 1; + + // This exists in all but the first item of each checkpoint file. 
+ SavedSlice data = 2; +}; diff --git a/tensorflow/core/util/saved_tensor_slice_util.cc b/tensorflow/core/util/saved_tensor_slice_util.cc new file mode 100644 index 0000000000..7a5903f07f --- /dev/null +++ b/tensorflow/core/util/saved_tensor_slice_util.cc @@ -0,0 +1,76 @@ +#include "tensorflow/core/util/saved_tensor_slice_util.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/ordered_code.h" + +namespace tensorflow { + +namespace checkpoint { + +const char kSavedTensorSlicesKey[] = ""; + +string EncodeTensorNameSlice(const string& name, const TensorSlice& slice) { + string buffer; + // All the tensor slice keys will start with a 0 + tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, 0); + tensorflow::strings::OrderedCode::WriteString(&buffer, name); + tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, slice.dims()); + for (int d = 0; d < slice.dims(); ++d) { + // A trivial extent (meaning we take EVERYTHING) will default to -1 for both + // start and end. These will be properly parsed. + tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer, + slice.start(d)); + tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer, + slice.length(d)); + } + return buffer; +} + +Status DecodeTensorNameSlice(const string& code, string* name, + tensorflow::TensorSlice* slice) { + StringPiece src(code); + uint64 x; + if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) { + return errors::Internal("Failed to parse the leading number: src = ", src); + } + if (x != 0) { + return errors::Internal( + "The leading number should always be 0 for any valid key: src = ", src); + } + if (!tensorflow::strings::OrderedCode::ReadString(&src, name)) { + return errors::Internal("Failed to parse the tensor name: src = ", src); + } + if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) { + return errors::Internal("Failed to parse the tensor rank: src = ", src); + } + if (x == 0) { + return errors::Internal("Expecting positive rank of the tensor, got ", x, + ", src = ", src); + } + if (x >= kint32max) { + return errors::Internal("Too many elements ", x); + } + slice->SetFullSlice(x); + for (int d = 0; d < static_cast<int32>(x); ++d) { + // We expected 2x integers + int64 start, length; + if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src, + &start)) { + return errors::Internal("Failed to parse start: src = ", src); + } + if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src, + &length)) { + return errors::Internal("Failed to parse length: src = ", src); + } + if (length >= 0) { + // a non-trivial extent + slice->set_start(d, start); + slice->set_length(d, length); + } + } + return Status::OK(); +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h new file mode 100644 index 0000000000..6206cd8538 --- /dev/null +++ b/tensorflow/core/util/saved_tensor_slice_util.h @@ -0,0 +1,110 @@ +// Utilities for saving/restoring tensor slice checkpoints. 
+ +#ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ +#define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ + +#include <string> // for string +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/status.h" // for Status + +namespace tensorflow { + +namespace checkpoint { + +// The key for the metadata in the tensor slice checkpoint files. It is "" so +// that the metadata is always at the beginning of a checkpoint file. +extern const char kSavedTensorSlicesKey[]; + +// Encode a tensor name + a tensor slice into an ordered code and outputs it as +// a string. +// The format is +// <0> +// <tensor_name> +// <rank> +// <dim-0-start><dim-0-length> +// <dim-1-start><dim-1-length> +// ... + +string EncodeTensorNameSlice(const string& name, + const tensorflow::TensorSlice& slice); + +// Parse out the name and the slice from string encoded as an ordered code. +Status DecodeTensorNameSlice(const string& code, string* name, + tensorflow::TensorSlice* slice); + +template <typename T> +struct SaveTypeTraits; + +template <typename T> +const typename SaveTypeTraits<T>::SavedType* TensorProtoData( + const TensorProto& t); + +template <typename T> +protobuf::RepeatedField<typename SaveTypeTraits<T>::SavedType>* +MutableTensorProtoData(TensorProto* t); + +template <typename T> +void Fill(T* data, size_t n, TensorProto* t); + +#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \ + template <> \ + struct SaveTypeTraits<TYPE> { \ + static constexpr bool supported = true; \ + typedef FTYPE SavedType; \ + }; \ + template <> \ + inline const FTYPE* TensorProtoData<TYPE>(const TensorProto& t) { \ + static_assert(SaveTypeTraits<TYPE>::supported, \ + "Specified type " #TYPE " not supported for Restore"); \ + return reinterpret_cast<const FTYPE*>(t.FIELD##_val().data()); \ + } \ + template <> \ + inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>( \ + TensorProto * t) { \ + static_assert(SaveTypeTraits<TYPE>::supported, \ + "Specified type " #TYPE " not supported for Save"); \ + return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>( \ + t->mutable_##FIELD##_val()); \ + } \ + template <> \ + inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ + typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \ + t->mutable_##FIELD##_val()->Swap(©); \ + } + +TENSOR_PROTO_EXTRACT_TYPE(float, float, float); +TENSOR_PROTO_EXTRACT_TYPE(double, double, double); +TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(int64, int64, int64); +TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); + +#undef TENSOR_PROTO_EXTRACT_TYPE + +template <> +struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {}; + +template <> +inline const int32* TensorProtoData<qint32>(const TensorProto& t) { + static_assert(SaveTypeTraits<qint32>::supported, + "Specified type qint32 not supported for Restore"); + return reinterpret_cast<const int32*>(t.int_val().data()); +} + +inline void Fill(const qint32* data, size_t n, TensorProto* t) { + const int32* p = reinterpret_cast<const int32*>(data); + typename protobuf::RepeatedField<int32> copy(p, p + n); + t->mutable_int_val()->Swap(©); +} + +} // namespace checkpoint + +} // namespace tensorflow + 
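// A minimal usage sketch of the Fill()/TensorProtoData<T>() helpers defined
// above (the buffer contents are illustrative): Fill() copies a typed buffer
// into the matching repeated *_val field of a TensorProto, and
// TensorProtoData<T>() reads that field back as a typed pointer.
//
//   const float buf[] = {1.0f, 2.0f, 3.0f};
//   TensorProto proto;
//   Fill(buf, 3, &proto);                          // populates float_val
//   const float* data = TensorProtoData<float>(proto);
//   CHECK_EQ(data[1], 2.0f);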
+#endif // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ diff --git a/tensorflow/core/util/saved_tensor_slice_util_test.cc b/tensorflow/core/util/saved_tensor_slice_util_test.cc new file mode 100644 index 0000000000..2c34c903db --- /dev/null +++ b/tensorflow/core/util/saved_tensor_slice_util_test.cc @@ -0,0 +1,32 @@ +#include "tensorflow/core/util/saved_tensor_slice_util.h" + +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +// Testing serialization of tensor name and tensor slice in the ordered code +// format. +TEST(TensorShapeUtilTest, TensorNameSliceToOrderedCode) { + { + TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5"); + string buffer = EncodeTensorNameSlice("foo", s); + string name; + s.Clear(); + TF_CHECK_OK(DecodeTensorNameSlice(buffer, &name, &s)); + EXPECT_EQ("foo", name); + EXPECT_EQ("-:-:1,3:4,5", s.DebugString()); + } +} + +} // namespace + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/sparse/README.md b/tensorflow/core/util/sparse/README.md new file mode 100644 index 0000000000..7b0799eb0e --- /dev/null +++ b/tensorflow/core/util/sparse/README.md @@ -0,0 +1,222 @@ +SparseTensor +============ + +Sparse Tensors are stored as two dense tensors and a shape: + +* `indices`: a `brain::Tensor` storing a matrix, typically `int64` +* `values`: a `brain::Tensor` storing a vector with values of type T. +* `shape`: a `TensorShape` storing the bounds of the underlying tensor +* `order`: (optional) a `gtl::InlinedVector<int64,8>` with the dimensions + along which the indices are ordered. + +Let + + ix = indices.matrix<int64>() + vals = values.vec<T>() + +The shape of `ix` is `N x NDIMS`, and each row corresponds to the +index of a single element of the sparse tensor. + +The length of `vals` must be `N`, and `vals(i)` corresponds to the +value with index `ix(i,:)`. + +Shape must be a `TensorShape` with `dims() == NDIMS`. +The shape is the full shape of the dense tensor these indices +represent. + +To be specific, the representation (pseudocode) is: + + tensor[ix[i,:]] == vals[i] for i = 0, ..., N-1 + +Ordering +-------- + +Indices need not be provided in order. For example, the following +index matrix is ordered according to dimension order `{0, 1, 2}`. + + [0 0 1] + [0 1 1] + [2 0 2] + +However, you can provide an unordered version: + + [2 0 2] + [0 0 1] + [0 1 1] + +If the SparseTensor is constructed without a provided order, then a +the default order is `{-1, ..., -1}`. Certain operations will fail or crash +when the order is not provided. + +Resorting the SparseTensor in-place (which resorts the underlying index and +values tensors in-place) will update the order. The cost of reordering the +matrix is `O(N*log(N))`, and requires `O(N)` additional temporary space to store +a reordering index. If the default order is not specified and reordering is not +performed, the following will happen: + +* `group()` will **raise an assertion failure** +* `IndicesValid()` will **raise an assertion failure** + +To update the internal index ordering after construction, call +`Reorder<T>()` via, e.g., `Reorder<T>({0,1,2})`. +After this step, all the above methods should work correctly. 
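For example (a sketch; the shapes and value type are illustrative):

    Tensor ix(DT_INT64, TensorShape({N, NDIMS}));
    Tensor vals(DT_STRING, TensorShape({N}));
    TensorShape shape({10, 10, 10});

    // No order provided: order() is {-1, ..., -1}, so group() and
    // IndicesValid() cannot be used yet.
    SparseTensor sp(ix, vals, shape);

    // Sort by dimension 0, then 1, then 2; order() becomes {0, 1, 2}.
    sp.Reorder<string>({0, 1, 2});
    CHECK(sp.IndicesValid());  // assuming no duplicate or out-of-range indices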
+ +The method `IndicesValid()` checks to make sure: + +* `0 <= ix(i, d) < shape.dim_size(d)` +* indices do not repeat +* indices are in order + +Iterating +--------- + +### group({grouping dims}) + +* provides an iterator that groups entries according to + dimensions you care about +* may require a sort if your data isn't presorted in a way that's + compatible with grouping_dims +* for each group, returns the group index (values of the group + dims for this iteration), the subset of indices in this group, + and the subset of values in this group. these are lazy outputs + so to read them individually, copy them as per the example + below. + +#### **NOTE** +`group({dim0, ..., dimk})` will **raise an assertion failure** if the +order of the SparseTensor does not match the dimensions you wish to group by. +You must either have your indices in the correct order and construct the +SparseTensor with + + order = {dim0, ..., dimk, ...} + +or call + + Reorder<T>({dim0, .., dimk, ...}) + +to sort the SparseTensor before grouping. + +Example of grouping: + + Tensor indices(DT_INT64, TensorShape({N, NDIMS}); + Tensor values(DT_STRING, TensorShape({N}); + TensorShape shape({dim0,...}); + SparseTensor sp(indices, vals, shape); + sp.Reorder<string>({1, 2, 0, 3, ...}); // Must provide NDIMS dims. + // group according to dims 1 and 2 + for (const auto& g : sp.group({1, 2})) { + cout << "vals of ix[:, 1,2] for this group: " + << g.group()[0] << ", " << g.group()[1]; + cout << "full indices of group:\n" << g.indices(); + cout << "values of group:\n" << g.values(); + + TTypes<int64>::UnalignedMatrix g_ix = g.indices(); + TTypes<string>::UnalignedVec g_v = g.values(); + ASSERT(g_ix.dimension(0) == g_v.size()); // number of elements match. + } + + +ToDense +-------- + +Converts sparse tensor to dense. You must provide a pointer to the +dense tensor (preallocated). `ToDense()` will optionally +preinitialize the tensor with zeros. + +Shape checking is performed, as is boundary checking. + + Tensor indices(DT_INT64, TensorShape({N, NDIMS}); + Tensor values(DT_STRING, TensorShape({N}); + TensorShape shape({dim0,...}); + SparseTensor sp(indices, vals, shape); + ASSERT(sp.IndicesValid()); // checks ordering & index bounds. + + Tensor dense(DT_STRING, shape); + // initialize other indices to zero. copy. + ASSERT(sp.ToDense<string>(&dense, true)); + + +Concat +-------- + +Concatenates multiple SparseTensors and returns a new SparseTensor. +This concatenation is with respect to the "dense" versions of these +SparseTensors. Concatenation is performed along dimension order[0] +of all tensors. As a result, shape[order[0]] may differ across +the inputs, but shape[d] for d != order[0] must match across all inputs. + +We call order[0] the **primary dimension**. + +**Prerequisites** + +* The inputs' ranks must all match. +* The inputs' order[0] must all match. +* The inputs' shapes must all match except for dimension order[0]. +* The inputs' values must all be of the same type. + +If any of these are false, concat will die with an assertion failure. + +Example: +Concatenate two sparse matrices along columns. 
+ +Matrix 1: + + [0 0 1] + [2 0 0] + [3 0 4] + +Matrix 2: + + [0 0 0 0 0] + [0 1 0 0 0] + [2 0 0 1 0] + +Concatenated Matrix: + + [0 0 1 0 0 0 0 0] + [2 0 0 0 1 0 0 0] + [3 0 4 2 0 0 1 0] + +Expected input shapes, orders, and `nnz()`: + + shape_1 = TensorShape({3, 3}) + shape_2 = TensorShape({3, 8}) + order_1 = {1, 0} // primary order is 1, columns + order_2 = {1, 0} // primary order is 1, must match + nnz_1 = 4 + nnz_2 = 3 + +Output shapes and orders: + + conc_shape = TensorShape({3, 11}) // primary dim increased, others same + conc_order = {1, 0} // Orders match along all inputs + conc_nnz = 7 // Sum of nonzeros of inputs + +Coding Example: + + Tensor ix1(DT_INT64, TensorShape({N1, 3}); + Tensor vals1(DT_STRING, TensorShape({N1, 3}); + Tensor ix2(DT_INT64, TensorShape({N2, 3}); + Tensor vals2(DT_STRING, TensorShape({N2, 3}); + Tensor ix3(DT_INT64, TensorShape({N3, 3}); + Tensor vals3(DT_STRING, TensorShape({N3, 3}); + + SparseTensor st1(ix1, vals1, TensorShape({10, 20, 5}), {1, 0, 2}); + SparseTensor st2(ix2, vals2, TensorShape({10, 10, 5}), {1, 0, 2}); + // For kicks, st3 indices are out of order, but order[0] matches so we + // can still concatenate along this dimension. + SparseTensor st3(ix3, vals3, TensorShape({10, 30, 5}), {1, 2, 0}); + + SparseTensor conc = SparseTensor::Concat<string>({st1, st2, st3}); + Tensor ix_conc = conc.indices(); + Tensor vals_conc = conc.values(); + EXPECT_EQ(conc.nnz(), st1.nnz() + st2.nnz() + st3.nnz()); + EXPECT_EQ(conc.Shape(), TensorShape({10, 60, 5})); + EXPECT_EQ(conc.Order(), {-1, -1, -1}); + + // Reorder st3 so all input tensors have the exact same orders. + st3.Reorder<string>({1, 0, 2}); + SparseTensor conc2 = SparseTensor::Concat<string>({st1, st2, st3}); + EXPECT_EQ(conc2.Order(), {1, 0, 2}); + // All indices' orders matched, so output is in order. + EXPECT_TRUE(conc2.IndicesValid()); diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h new file mode 100644 index 0000000000..57473867cf --- /dev/null +++ b/tensorflow/core/util/sparse/dim_comparator.h @@ -0,0 +1,60 @@ +#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ +#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { + +///////////////// +// DimComparator +///////////////// +// +// Helper class, mainly used by the IndexSortOrder. This comparator +// can be passed to e.g. std::sort, or any other sorter, to sort two +// rows of an index matrix according to the dimension(s) of interest. +// The dimensions to sort by are passed to the constructor as "order". +// +// Example: if given index matrix IX, two rows ai and bi, and order = {2,1}. +// operator() compares +// IX(ai,2) < IX(bi,2). +// If IX(ai,2) == IX(bi,2), it compares +// IX(ai,1) < IX(bi,1). +// +// This can be used to sort a vector of row indices into IX according to +// the values in IX in particular columns (dimensions) of interest. 
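// A minimal sketch of the typical use (the names "ix", "N" and "dims" are
// illustrative): sort a vector of row numbers rather than the index matrix
// itself, then apply the resulting permutation to the rows.
//
//   auto ix_t = ix.matrix<int64>();               // N x dims index matrix
//   std::vector<int64> order{0, 1};               // sort by dim 0, then dim 1
//   std::vector<int64> rows(N);
//   std::iota(rows.begin(), rows.end(), 0);
//   DimComparator sorter(ix_t, order, dims);
//   std::sort(rows.begin(), rows.end(), sorter);  // rows now holds the row
//                                                 // numbers in sorted order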
+class DimComparator { + public: + typedef typename gtl::ArraySlice<int64> VarDimArray; + + inline DimComparator(const TTypes<int64>::Matrix& ix, + const VarDimArray& order, int dims) + : ix_(ix), order_(order), dims_(dims) { + CHECK_GT(order.size(), 0) << "Must order using at least one index"; + CHECK_LE(order.size(), dims_) << "Can only sort up to dims"; + for (size_t d = 0; d < order.size(); ++d) { + CHECK_GE(order[d], 0); + CHECK_LT(order[d], dims); + } + } + + inline bool operator()(const int64 i, const int64 j) const { + for (int di = 0; di < dims_; ++di) { + const int64 d = order_[di]; + if (ix_(i, d) < ix_(j, d)) return true; + if (ix_(i, d) > ix_(j, d)) return false; + } + return false; + } + + const TTypes<int64>::Matrix ix_; + const VarDimArray order_; + const int dims_; +}; + +} // namespace sparse +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc new file mode 100644 index 0000000000..e153bcdbb4 --- /dev/null +++ b/tensorflow/core/util/sparse/group_iterator.cc @@ -0,0 +1,49 @@ +#include "tensorflow/core/util/sparse/group_iterator.h" + +namespace tensorflow { +namespace sparse { + +void GroupIterable::IteratorStep::UpdateEndOfGroup() { + ++next_loc_; + int64 N = iter_->ix_.dim_size(0); + auto ix_t = iter_->ix_.template matrix<int64>(); + while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) { + ++next_loc_; + } +} + +bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const { + CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators"; + return (rhs.loc_ != loc_); +} + +GroupIterable::IteratorStep& GroupIterable::IteratorStep:: +operator++() { // prefix ++ + loc_ = next_loc_; + UpdateEndOfGroup(); + return *this; +} + +GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++( + int) { // postfix ++ + IteratorStep lhs(*this); + ++(*this); + return lhs; +} + +std::vector<int64> Group::group() const { + std::vector<int64> g; + auto ix_t = iter_->ix_.template matrix<int64>(); + for (const int d : iter_->group_dims_) { + g.push_back(ix_t(loc_, d)); + } + return g; +} + +TTypes<int64>::UnalignedConstMatrix Group::indices() const { + return TTypes<int64>::UnalignedConstMatrix( + &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_); +} + +} // namespace sparse +} // namespace tensorflow diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h new file mode 100644 index 0000000000..8423d54f27 --- /dev/null +++ b/tensorflow/core/util/sparse/group_iterator.h @@ -0,0 +1,120 @@ +#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ +#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { + +class GroupIterable; // Predeclare GroupIterable for Group. + +// This class is returned when dereferencing a GroupIterable iterator. +// It provides the methods group(), indices(), and values(), which +// provide access into the underlying SparseTensor. 
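// A minimal sketch of consuming Groups (assuming a hypothetical int32-valued
// SparseTensor "sp" that is already ordered by dimension 0):
//
//   for (const auto& g : sp.group({0})) {
//     std::vector<int64> key = g.group();    // value of dim 0 for this group
//     auto group_ix = g.indices();           // indices of entries in the group
//     auto group_vals = g.values<int32>();   // the matching values
//   }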
+class Group { + public: + Group(GroupIterable* iter, int64 loc, int64 next_loc) + : iter_(iter), loc_(loc), next_loc_(next_loc) {} + + std::vector<int64> group() const; + TTypes<int64>::UnalignedConstMatrix indices() const; + template <typename T> + typename TTypes<T>::UnalignedVec values() const; + + private: + GroupIterable* iter_; + int64 loc_; + int64 next_loc_; +}; + +///////////////// +// GroupIterable +///////////////// +// +// Returned when calling sparse_tensor.group({dim0, dim1, ...}). +// +// Please note: the sparse_tensor should already be ordered according +// to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups. +// +// Allows grouping and iteration of the SparseTensor according to the +// subset of dimensions provided to the group call. +// +// The actual grouping dimensions are stored in the +// internal vector group_dims_. Iterators inside the iterable provide +// the three methods: +// +// * group(): returns a vector with the current group dimension values. +// * indices(): a map of index, providing the indices in +// this group. +// * values(): a map of values, providing the values in +// this group. +// +// To iterate across GroupIterable, see examples in README.md. +// + +// Forward declaration of SparseTensor +class GroupIterable { + public: + typedef gtl::ArraySlice<int64> VarDimArray; + + GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims) + : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {} + + class IteratorStep; + + IteratorStep begin() { return IteratorStep(this, 0); } + IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); } + + template <typename TIX> + inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const { + bool matches = true; + for (int d : group_dims_) { + if (ix(loc_a, d) != ix(loc_b, d)) { + matches = false; + } + } + return matches; + } + + class IteratorStep { + public: + IteratorStep(GroupIterable* iter, int64 loc) + : iter_(iter), loc_(loc), next_loc_(loc_) { + UpdateEndOfGroup(); + } + + void UpdateEndOfGroup(); + bool operator!=(const IteratorStep& rhs) const; + IteratorStep& operator++(); // prefix ++ + IteratorStep operator++(int); // postfix ++ + Group operator*() const { return Group(iter_, loc_, next_loc_); } + + private: + GroupIterable* iter_; + int64 loc_; + int64 next_loc_; + }; + + private: + friend class Group; + Tensor ix_; + Tensor vals_; + const int dims_; + const VarDimArray group_dims_; +}; + +// Implementation of Group::values<T>() +template <typename T> +typename TTypes<T>::UnalignedVec Group::values() const { + return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)), + next_loc_ - loc_); +} + +} // namespace sparse +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h new file mode 100644 index 0000000000..dcb75e7f54 --- /dev/null +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -0,0 +1,353 @@ +#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ +#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ + +#include <limits> + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/public/tensor.h" +#include 
"tensorflow/core/util/sparse/dim_comparator.h" +#include "tensorflow/core/util/sparse/group_iterator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { + +class SparseTensor { + public: + typedef typename gtl::ArraySlice<int64> VarDimArray; + + SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape) + : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {} + + SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape, + const VarDimArray& order) + : ix_(ix), + vals_(vals), + shape_(shape), + order_(order.begin(), order.end()), + dims_(GetDimsFromIx(ix)) { + CHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: " + << ix.dtype(); + CHECK(TensorShapeUtils::IsMatrix(ix.shape())) + << "indices must be a matrix, but got: " << ix.shape().DebugString(); + CHECK(TensorShapeUtils::IsVector(vals.shape())) + << "vals must be a vec, but got: " << vals.shape().DebugString(); + CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) + << "indices and values rows (indexing dimension) must match."; + } + + std::size_t num_entries() const { return ix_.dim_size(0); } + + const Tensor& indices() const { return ix_; } + + const Tensor& values() const { return vals_; } + + DataType dtype() const { return vals_.dtype(); } + + bool IndicesValid() const { + const auto ix_t = ix_.matrix<int64>(); + for (int64 ord : order_) { + CHECK_GE(ord, 0) << "Order was not provided. Provide an order at " + "construction time or run ReorderInPlace"; + } + + for (std::size_t n = 0; n < num_entries(); ++n) { + if (!IndexValid(ix_t, n)) return false; + } + + return true; + } + + // Returns the tensor shape (the dimensions of the "densified" + // tensor this tensor represents). + const TensorShape shape() const { return shape_; } + + const VarDimArray order() const { return order_; } + + // Resorts the indices and values according to the dimensions in order. + template <typename T> + void Reorder(const VarDimArray& order); + + // Returns a group iterable that can be used for clumping indices + // and values according to the group indices of interest. + // + // Precondition: order()[0..group_ix.size()] == group_ix. + // + // See the README.md in this directory for more usage information. + GroupIterable group(const VarDimArray& group_ix) { + CHECK_LE(group_ix.size(), dims_); + for (std::size_t di = 0; di < group_ix.size(); ++di) { + CHECK_GE(group_ix[di], 0) << "Group dimension out of range"; + CHECK_LT(group_ix[di], dims_) << "Group dimension out of range"; + CHECK_EQ(group_ix[di], order_[di]) + << "Group dimension does not match sorted order"; + } + return GroupIterable(ix_, vals_, dims_, group_ix); + } + + // Stores the sparse indices into the dense tensor out. + // Preconditions: + // out->shape().dims() == shape().dims() + // out->shape().dim_size(d) >= shape(d) for all d + // + // Returns true on success. False on failure (mismatched dimensions + // or out-of-bounds indices). + // + // If initialize==True, ToDense first overwrites all coefficients in out to 0. + // + template <typename T> + bool ToDense(Tensor* out, bool initialize = true); + + // Concat() will concatenate all the tensors according to their first order + // dimension. All tensors must have identical shape except for + // the first order dimension. All tensors orders' first dimension + // must match. + // + // If all of the tensors have identical ordering, then the output + // will have this ordering. 
Otherwise the output is set as not + // having any order and a Reorder<T>() should be called on it before + // performing any subsequent operations. + template <typename T> + static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors); + + private: + static int GetDimsFromIx(const Tensor& ix) { + CHECK(TensorShapeUtils::IsMatrix(ix.shape())); + return ix.dim_size(1); + } + + static gtl::InlinedVector<int64, 8> UndefinedOrder(const TensorShape& shape) { + return gtl::InlinedVector<int64, 8>(shape.dims(), -1); + } + + // Helper for IndicesValid() + inline bool IndexValid(const TTypes<int64>::ConstMatrix& ix_t, + int64 n) const { + bool different = false; + bool bad_order = false; + bool valid = true; + if (n == 0) { + for (int di = 0; di < dims_; ++di) { + if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di)) + valid = false; + } + different = true; + } else { + for (int di = 0; di < dims_; ++di) { + if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di)) + valid = false; + int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]); + if (diff > 0) different = true; + if (!different && diff < 0) bad_order = true; + } + } + if (!valid) return false; // Out of bounds + if (!different) return false; // The past two indices are identical... + if (bad_order) return false; // Decreasing in order. + return true; + } + + // Helper for ToDense<T>() + template <typename T> + bool ValidateAndInitializeToDense(Tensor* out, bool initialize); + + Tensor ix_; + Tensor vals_; + TensorShape shape_; + gtl::InlinedVector<int64, 8> order_; + const int dims_; +}; + +// This operation updates the indices and values Tensor rows, so it is +// an in-place algorithm. It requires O(N log N) time and O(N) +// temporary space. +template <typename T> +void SparseTensor::Reorder(const VarDimArray& order) { + CHECK_EQ(DataTypeToEnum<T>::v(), dtype()) + << "Reorder requested with the wrong datatype"; + CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank"; + auto ix_t = ix_.matrix<int64>(); + auto vals_t = vals_.vec<T>(); + + DimComparator sorter(ix_t, order, dims_); + + std::vector<int64> reorder(num_entries()); + std::iota(reorder.begin(), reorder.end(), 0); + + // Sort to get order of indices + std::sort(reorder.begin(), reorder.end(), sorter); + + // We have a forward reordering, but what we'll need is a + // permutation (the inverse). This can be calculated with O(1) + // additional + // and O(n) time (INVPERM) but we just do the simple thing here. + std::vector<int64> permutation(reorder.size()); + for (std::size_t n = 0; n < reorder.size(); ++n) { + permutation[reorder[n]] = n; + } + + // Update indices & values by converting the permutations to + // a product of transpositions. Iterate over the cycles in the + // permutation, and convert each of those into a product of + // transpositions (swaps): + // https://en.wikipedia.org/wiki/Cyclic_permutation + // This is N swaps, 2*N comparisons. 
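// For example (a sketch), if permutation == {2, 0, 1}: the pass at n == 0
// first swaps rows 0 and 2, then rows 0 and 1, after which every row sits at
// its target position and the remaining outer iterations are no-ops.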
+ for (std::size_t n = 0; n + 1 < permutation.size(); ++n) { + while (n != permutation[n]) { + std::size_t r = permutation[n]; + std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0))); + std::swap(vals_t(n), vals_t(r)); + std::swap(permutation[n], permutation[r]); + } + } + + order_ = gtl::InlinedVector<int64, 8>(order.begin(), order.end()); +} + +template <typename T> +bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) { + CHECK_EQ(DataTypeToEnum<T>::v(), dtype()) + << "ToDense requested with the wrong datatype"; + + CHECK_EQ(out->shape().dims(), dims_) + << "Incompatible dimensions between SparseTensor and output"; + + CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v()) + << "Output must be type: " << DataTypeToEnum<T>::v() + << " but got: " << out->dtype(); + + // Make sure the dense output is the same rank and has room + // to hold the SparseTensor. + const auto& out_shape = out->shape(); + if (shape_.dims() != out_shape.dims()) return false; + for (int d = 0; d < shape_.dims(); ++d) { + if (shape_.dim_size(d) > out_shape.dim_size(d)) return false; + } + + if (initialize) { + auto out_t = out->flat<T>(); + out_t.setConstant(T()); + } + + return true; +} + +template <typename T> +bool SparseTensor::ToDense(Tensor* out, bool initialize) { + if (!ValidateAndInitializeToDense<T>(out, initialize)) return false; + + auto out_t = out->flat<T>(); + auto ix_t = ix_.matrix<int64>(); + auto vals_t = vals_.vec<T>(); + + std::vector<int64> strides(dims_); + const auto& out_shape = out->shape(); + strides[dims_ - 1] = 1; + for (int d = dims_ - 2; d >= 0; --d) { + strides[d] = strides[d + 1] * out_shape.dim_size(d + 1); + } + + for (std::size_t n = 0; n < vals_t.dimension(0); ++n) { + bool invalid_dims = false; + int64 ix = 0; + for (int d = 0; d < dims_; ++d) { + const int64 ix_n_d = ix_t(n, d); + if (ix_n_d < 0 || ix_n_d >= out_shape.dim_size(d)) { + invalid_dims = true; + } + ix += strides[d] * ix_n_d; + } + if (invalid_dims) return false; + out_t(ix) = vals_t(n); + } + return true; +} + +template <typename T> +SparseTensor SparseTensor::Concat( + const gtl::ArraySlice<SparseTensor>& tensors) { + CHECK_GE(tensors.size(), 1) << "Cannot concat 0 SparseTensors"; + const int dims = tensors[0].dims_; + CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors"; + auto order_0 = tensors[0].order(); + const int primary_dim = order_0[0]; + gtl::InlinedVector<int64, 8> final_order(order_0.begin(), order_0.end()); + TensorShape final_shape(tensors[0].shape()); + final_shape.set_dim(primary_dim, 0); // We'll build this up as we go along. + int num_entries = 0; + + bool fully_ordered = true; + for (const SparseTensor& st : tensors) { + CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank."; + CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype()) + << "Concat requested with the wrong data type"; + CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered"; + CHECK_EQ(st.order()[0], primary_dim) + << "All SparseTensors' order[0] must match. This is the concat dim."; + if (st.order() != final_order) fully_ordered = false; + const TensorShape st_shape = st.shape(); + for (int d = 0; d < dims - 1; ++d) { + const int cdim = (d < primary_dim) ? d : d + 1; + CHECK_EQ(final_shape.dim_size(cdim), st_shape.dim_size(cdim)) + << "All SparseTensors' shapes must match except on the concat dim. " + << "Concat dim: " << primary_dim + << ", mismatched shape at dim: " << cdim + << ". 
Expecting shape like: " << final_shape.DebugString() + << " but saw shape: " << st_shape.DebugString(); + } + + // Update dimension of final shape + final_shape.set_dim(primary_dim, final_shape.dim_size(primary_dim) + + st_shape.dim_size(primary_dim)); + + num_entries += st.num_entries(); // Update number of entries + } + + // If nonconsistent ordering among inputs, set final order to -1s. + if (!fully_ordered) { + final_order = UndefinedOrder(final_shape); + } + + Tensor output_ix(DT_INT64, TensorShape({num_entries, dims})); + Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries})); + + auto ix_t = output_ix.matrix<int64>(); + auto vals_t = output_vals.vec<T>(); + + Eigen::DenseIndex offset = 0; + int64 shape_offset = 0; + for (const SparseTensor& st : tensors) { + int st_num_entries = st.num_entries(); + Eigen::DSizes<Eigen::DenseIndex, 2> ix_start(offset, 0); + Eigen::DSizes<Eigen::DenseIndex, 2> ix_size(st_num_entries, dims); + Eigen::DSizes<Eigen::DenseIndex, 1> vals_start(offset); + Eigen::DSizes<Eigen::DenseIndex, 1> vals_size(st_num_entries); + + // Fill in indices & values. + ix_t.slice(ix_start, ix_size) = st.ix_.matrix<int64>(); + vals_t.slice(vals_start, vals_size) = st.vals_.vec<T>(); + + Eigen::DSizes<Eigen::DenseIndex, 2> ix_update_start(offset, primary_dim); + Eigen::DSizes<Eigen::DenseIndex, 2> ix_update_size(st_num_entries, 1); + // The index associated with the primary dimension gets increased + // by the shapes of the previous concatted Tensors. + auto update_slice = ix_t.slice(ix_update_start, ix_update_size); + update_slice += update_slice.constant(shape_offset); + + offset += st_num_entries; + shape_offset += st.shape().dim_size(primary_dim); + } + + return SparseTensor(output_ix, output_vals, final_shape, final_order); +} + +} // namespace sparse +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc new file mode 100644 index 0000000000..47126b7187 --- /dev/null +++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc @@ -0,0 +1,467 @@ +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +#include <string> +#include <vector> + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/public/tensor.h" +#include <gtest/gtest.h> +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace sparse { +namespace { + +Eigen::Tensor<int64, 2, Eigen::RowMajor, Eigen::DenseIndex> +GetSimpleIndexTensor(int N, const int NDIM) { + Eigen::Tensor<int64, 2, Eigen::RowMajor, Eigen::DenseIndex> ix(N, NDIM); + ix(0, 0) = 0; + ix(0, 1) = 0; + ix(0, 2) = 0; + + ix(1, 0) = 3; + ix(1, 1) = 0; + ix(1, 2) = 0; + + ix(2, 0) = 2; + ix(2, 1) = 0; + ix(2, 2) = 0; + + ix(3, 0) = 0; + ix(3, 1) = 1; + ix(3, 2) = 0; + + ix(4, 0) = 0; + ix(4, 1) = 0; + ix(4, 2) = 2; + return ix; +} + +TEST(SparseTensorTest, DimComparatorSorts) { + std::size_t N = 5; + const int NDIM = 3; + auto ix = GetSimpleIndexTensor(N, NDIM); + TTypes<int64>::Matrix map(ix.data(), N, NDIM); + + std::vector<int64> sorting(N); + for (std::size_t n = 0; n < N; ++n) sorting[n] = n; + + // new order should be: {0, 4, 3, 2, 1} + std::vector<int64> order{0, 1, 2}; + DimComparator sorter(map, order, NDIM); + std::sort(sorting.begin(), sorting.end(), sorter); + + EXPECT_EQ(sorting, std::vector<int64>({0, 4, 3, 2, 1})); + + // new order should be: {0, 3, 2, 1, 4} + 
std::vector<int64> order1{2, 0, 1}; + DimComparator sorter1(map, order1, NDIM); + for (std::size_t n = 0; n < N; ++n) sorting[n] = n; + std::sort(sorting.begin(), sorting.end(), sorter1); + + EXPECT_EQ(sorting, std::vector<int64>({0, 3, 2, 1, 4})); +} + +TEST(SparseTensorTest, SparseTensorConstruction) { + int N = 5; + const int NDIM = 3; + auto ix_c = GetSimpleIndexTensor(N, NDIM); + Eigen::Tensor<string, 1, Eigen::RowMajor> vals_c(N); + vals_c(0) = "hi0"; + vals_c(1) = "hi1"; + vals_c(2) = "hi2"; + vals_c(3) = "hi3"; + vals_c(4) = "hi4"; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_t = ix.matrix<int64>(); + auto vals_t = vals.vec<string>(); + vals_t = vals_c; + ix_t = ix_c; + + TensorShape shape({10, 10, 10}); + std::vector<int64> order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + EXPECT_FALSE(st.IndicesValid()); // Out of order + + // Regardless of how order is updated; so long as there are no + // duplicates, the resulting indices are valid. + st.Reorder<string>({2, 0, 1}); + EXPECT_TRUE(st.IndicesValid()); + EXPECT_EQ(vals_t(0), "hi0"); + EXPECT_EQ(vals_t(1), "hi3"); + EXPECT_EQ(vals_t(2), "hi2"); + EXPECT_EQ(vals_t(3), "hi1"); + EXPECT_EQ(vals_t(4), "hi4"); + + ix_t = ix_c; + vals_t = vals_c; + st.Reorder<string>({0, 1, 2}); + EXPECT_TRUE(st.IndicesValid()); + EXPECT_EQ(vals_t(0), "hi0"); + EXPECT_EQ(vals_t(1), "hi4"); + EXPECT_EQ(vals_t(2), "hi3"); + EXPECT_EQ(vals_t(3), "hi2"); + EXPECT_EQ(vals_t(4), "hi1"); + + ix_t = ix_c; + vals_t = vals_c; + st.Reorder<string>({2, 1, 0}); + EXPECT_TRUE(st.IndicesValid()); +} + +TEST(SparseTensorTest, EmptySparseTensorAllowed) { + int N = 0; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + TensorShape shape({10, 10, 10}); + std::vector<int64> order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + EXPECT_TRUE(st.IndicesValid()); + EXPECT_EQ(st.order(), order); + + std::vector<int64> new_order{1, 0, 2}; + st.Reorder<string>(new_order); + EXPECT_TRUE(st.IndicesValid()); + EXPECT_EQ(st.order(), new_order); +} + +TEST(SparseTensorTest, SortingWorksCorrectly) { + int N = 30; + const int NDIM = 4; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + TensorShape shape({1000, 1000, 1000, 1000}); + SparseTensor st(ix, vals, shape); + + auto ix_t = ix.matrix<int64>(); + + for (int n = 0; n < 100; ++n) { + ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator<int64>(n + 1)); + ix_t = ix_t.abs() % 1000; + st.Reorder<string>({0, 1, 2, 3}); + EXPECT_TRUE(st.IndicesValid()); + st.Reorder<string>({3, 2, 1, 0}); + EXPECT_TRUE(st.IndicesValid()); + st.Reorder<string>({1, 0, 2, 3}); + EXPECT_TRUE(st.IndicesValid()); + st.Reorder<string>({3, 0, 2, 1}); + EXPECT_TRUE(st.IndicesValid()); + } +} + +TEST(SparseTensorTest, ValidateIndicesFindsInvalid) { + int N = 2; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + Eigen::Tensor<int64, 2, Eigen::RowMajor> ix_orig(N, NDIM); + ix_orig(0, 0) = 0; + ix_orig(0, 1) = 0; + ix_orig(0, 2) = 0; + + ix_orig(1, 0) = 0; + ix_orig(1, 1) = 0; + ix_orig(1, 2) = 0; + + auto ix_t = ix.matrix<int64>(); + ix_t = ix_orig; + + TensorShape shape({10, 10, 10}); + std::vector<int64> order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + + st.Reorder<string>(order); + EXPECT_FALSE(st.IndicesValid()); // two indices are identical + + ix_orig(1, 2) = 1; + ix_t = ix_orig; + 
st.Reorder<string>(order); + EXPECT_TRUE(st.IndicesValid()); // second index now (0, 0, 1) + + ix_orig(0, 2) = 1; + ix_t = ix_orig; + st.Reorder<string>(order); + EXPECT_FALSE(st.IndicesValid()); // first index now (0, 0, 1) +} + +TEST(SparseTensorTest, SparseTensorCheckBoundaries) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_t = GetSimpleIndexTensor(N, NDIM); + + ix.matrix<int64>() = ix_t; + + TensorShape shape({10, 10, 10}); + std::vector<int64> order{0, 1, 2}; + + SparseTensor st(ix, vals, shape, order); + EXPECT_FALSE(st.IndicesValid()); + + st.Reorder<string>(order); + EXPECT_TRUE(st.IndicesValid()); + + ix_t(0, 0) = 11; + ix.matrix<int64>() = ix_t; + st.Reorder<string>(order); + EXPECT_FALSE(st.IndicesValid()); + + ix_t(0, 0) = -1; + ix.matrix<int64>() = ix_t; + st.Reorder<string>(order); + EXPECT_FALSE(st.IndicesValid()); + + ix_t(0, 0) = 0; + ix.matrix<int64>() = ix_t; + st.Reorder<string>(order); + EXPECT_TRUE(st.IndicesValid()); +} + +TEST(SparseTensorTest, SparseTensorToDenseTensor) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_t = GetSimpleIndexTensor(N, NDIM); + auto vals_t = vals.vec<string>(); + + ix.matrix<int64>() = ix_t; + + vals_t(0) = "hi0"; + vals_t(1) = "hi1"; + vals_t(2) = "hi2"; + vals_t(3) = "hi3"; + vals_t(4) = "hi4"; + + TensorShape shape({4, 4, 5}); + std::vector<int64> order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + + Tensor dense(DT_STRING, TensorShape({4, 4, 5})); + st.ToDense<string>(&dense); + + auto dense_t = dense.tensor<string, 3>(); + Eigen::array<Eigen::DenseIndex, NDIM> ix_n; + for (int n = 0; n < N; ++n) { + for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d); + EXPECT_EQ(dense_t(ix_n), vals_t(n)); + } + + // Spot checks on the others + EXPECT_EQ(dense_t(0, 0, 1), ""); + EXPECT_EQ(dense_t(0, 0, 3), ""); + EXPECT_EQ(dense_t(3, 3, 3), ""); + EXPECT_EQ(dense_t(3, 3, 4), ""); +} + +TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_t = GetSimpleIndexTensor(N, NDIM); + auto vals_t = vals.vec<string>(); + + ix.matrix<int64>() = ix_t; + + vals_t(0) = "hi0"; + vals_t(1) = "hi1"; + vals_t(2) = "hi2"; + vals_t(3) = "hi3"; + vals_t(4) = "hi4"; + + TensorShape shape({4, 4, 5}); + std::vector<int64> order{0, 1, 2}; + SparseTensor st(ix, vals, shape, order); + + Tensor dense(DT_STRING, TensorShape({10, 10, 10})); + st.ToDense<string>(&dense); + + auto dense_t = dense.tensor<string, 3>(); + Eigen::array<Eigen::DenseIndex, NDIM> ix_n; + for (int n = 0; n < N; ++n) { + for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d); + EXPECT_EQ(dense_t(ix_n), vals_t(n)); + } + + // Spot checks on the others + EXPECT_EQ(dense_t(0, 0, 1), ""); + EXPECT_EQ(dense_t(0, 0, 3), ""); + EXPECT_EQ(dense_t(3, 3, 3), ""); + EXPECT_EQ(dense_t(3, 3, 4), ""); + EXPECT_EQ(dense_t(9, 0, 0), ""); + EXPECT_EQ(dense_t(9, 0, 9), ""); + EXPECT_EQ(dense_t(9, 9, 9), ""); +} + +TEST(SparseTensorTest, SparseTensorGroup) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_INT32, TensorShape({N})); + + auto ix_t = ix.matrix<int64>(); + auto vals_t = vals.vec<int32>(); + + ix_t = GetSimpleIndexTensor(N, NDIM); + + vals_t(0) = 1; // associated with ix (000) + vals_t(1) = 2; // associated with ix (300) + 
vals_t(2) = 3; // associated with ix (200) + vals_t(3) = 4; // associated with ix (010) + vals_t(4) = 5; // associated with ix (002) + + TensorShape shape({10, 10, 10}); + std::vector<int64> order{0, 1, 2}; + + SparseTensor st(ix, vals, shape, order); + st.Reorder<int32>(order); + + std::vector<std::vector<int64> > groups; + std::vector<TTypes<int64>::UnalignedConstMatrix> grouped_indices; + std::vector<TTypes<int32>::UnalignedVec> grouped_values; + + // Group by index 0 + auto gi = st.group({0}); + + // All the hard work is right here! + for (const auto& g : gi) { + groups.push_back(g.group()); + VLOG(1) << "Group: " << str_util::Join(g.group(), ","); + VLOG(1) << "Indices: " << g.indices(); + VLOG(1) << "Values: " << g.values<int32>(); + + grouped_indices.push_back(g.indices()); + grouped_values.push_back(g.values<int32>()); + } + + // Group by dimension 0, we have groups: 0--, 2--, 3-- + EXPECT_EQ(groups.size(), 3); + EXPECT_EQ(groups[0], std::vector<int64>({0})); + EXPECT_EQ(groups[1], std::vector<int64>({2})); + EXPECT_EQ(groups[2], std::vector<int64>({3})); + + std::vector<Eigen::Tensor<int64, 2, Eigen::RowMajor> > expected_indices; + std::vector<Eigen::Tensor<int32, 1, Eigen::RowMajor> > expected_vals; + + // First group: 000, 002, 010 + expected_indices.emplace_back(3, NDIM); // 3 x 3 tensor + expected_vals.emplace_back(3); // 3 x 5 x 1 x 1 tensor + expected_indices[0].setZero(); + expected_indices[0](1, 2) = 2; // 002 + expected_indices[0](2, 1) = 1; // 010 + expected_vals[0].setConstant(-1); + expected_vals[0](0) = 1; // val associated with ix 000 + expected_vals[0](1) = 5; // val associated with ix 002 + expected_vals[0](2) = 4; // val associated with ix 010 + + // Second group: 200 + expected_indices.emplace_back(1, NDIM); + expected_vals.emplace_back(1); + expected_indices[1].setZero(); + expected_indices[1](0, 0) = 2; // 200 + expected_vals[1](0) = 3; // val associated with ix 200 + + // Third group: 300 + expected_indices.emplace_back(1, NDIM); + expected_vals.emplace_back(1); + expected_indices[2].setZero(); + expected_indices[2](0, 0) = 3; // 300 + expected_vals[2](0) = 2; // val associated with ix 300 + + for (std::size_t gix = 0; gix < groups.size(); ++gix) { + // Compare indices + auto gi_t = grouped_indices[gix]; + Eigen::Tensor<bool, 0, Eigen::RowMajor> eval = + (gi_t == expected_indices[gix]).all(); + EXPECT_TRUE(eval()) << gix << " indices: " << gi_t << " vs. " + << expected_indices[gix]; + + // Compare values + auto gv_t = grouped_values[gix]; + eval = (gv_t == expected_vals[gix]).all(); + EXPECT_TRUE(eval()) << gix << " values: " << gv_t << " vs. 
" + << expected_vals[gix]; + } +} + +TEST(SparseTensorTest, Concat) { + int N = 5; + const int NDIM = 3; + + Tensor ix(DT_INT64, TensorShape({N, NDIM})); + Tensor vals(DT_STRING, TensorShape({N})); + + auto ix_c = GetSimpleIndexTensor(N, NDIM); + + auto ix_t = ix.matrix<int64>(); + auto vals_t = vals.vec<string>(); + + ix_t = ix_c; + + TensorShape shape({10, 10, 10}); + std::vector<int64> order{0, 1, 2}; + + SparseTensor st(ix, vals, shape, order); + EXPECT_FALSE(st.IndicesValid()); + st.Reorder<string>(order); + EXPECT_TRUE(st.IndicesValid()); + + SparseTensor concatted = SparseTensor::Concat<string>({st, st, st, st}); + EXPECT_EQ(concatted.order(), st.order()); + TensorShape expected_shape({40, 10, 10}); + EXPECT_EQ(concatted.shape(), expected_shape); + EXPECT_EQ(concatted.num_entries(), 4 * N); + EXPECT_TRUE(concatted.IndicesValid()); + + auto conc_ix_t = concatted.indices().matrix<int64>(); + auto conc_vals_t = concatted.values().vec<string>(); + + for (int n = 0; n < 4; ++n) { + for (int i = 0; i < N; ++i) { + // Dimensions match except the primary dim, which is offset by + // shape[order[0]] + EXPECT_EQ(conc_ix_t(n * N + i, 0), 10 * n + ix_t(i, 0)); + EXPECT_EQ(conc_ix_t(n * N + i, 1), ix_t(i, 1)); + EXPECT_EQ(conc_ix_t(n * N + i, 1), ix_t(i, 1)); + + // Values match + EXPECT_EQ(conc_vals_t(n * N + i), vals_t(i)); + } + } + + // Concat works if non-primary ix is out of order, but output order + // is not defined + SparseTensor st_ooo(ix, vals, shape, {0, 2, 1}); // non-primary ix OOO + SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo}); + std::vector<int64> expected_ooo{-1, -1, -1}; + EXPECT_EQ(conc_ooo.order(), expected_ooo); + EXPECT_EQ(conc_ooo.shape(), expected_shape); + EXPECT_EQ(conc_ooo.num_entries(), 4 * N); +} + +// TODO(ebrevdo): ReduceToDense(R={dim1,dim2,...}, reduce_fn, &output) +// reduce_fn sees slices of resorted values based on generator (dim: DDIMS), and +// slices of resorted indices on generator. 
+ +} // namespace +} // namespace sparse +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc new file mode 100644 index 0000000000..00bc16f105 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader.cc @@ -0,0 +1,230 @@ +#include "tensorflow/core/util/tensor_slice_reader.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/io/iterator.h" +#include "tensorflow/core/lib/io/match.h" +#include "tensorflow/core/lib/io/table.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_util.h" + +namespace tensorflow { + +namespace checkpoint { + +TensorSliceReader::Table::~Table() {} + +namespace { +class TensorSliceReaderTable : public TensorSliceReader::Table { + public: + explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t) + : file_(f), table_(t) {} + + ~TensorSliceReaderTable() override { + delete table_; + delete file_; + } + + bool Get(const string& key, string* value) override { + std::unique_ptr<table::Iterator> iter(table_->NewIterator()); + iter->Seek(key); + if (iter->Valid() && iter->key() == key) { + StringPiece v = iter->value(); + value->assign(v.data(), v.size()); + return true; + } else { + return false; + } + } + + private: + RandomAccessFile* file_; + table::Table* table_; +}; +} // namespace + +Status OpenTableTensorSliceReader(const string& fname, + TensorSliceReader::Table** result) { + *result = nullptr; + Env* env = Env::Default(); + RandomAccessFile* f = nullptr; + Status s = env->NewRandomAccessFile(fname, &f); + if (s.ok()) { + uint64 file_size; + s = env->GetFileSize(fname, &file_size); + if (s.ok()) { + table::Options options; + table::Table* table; + s = table::Table::Open(options, f, file_size, &table); + if (s.ok()) { + *result = new TensorSliceReaderTable(f, table); + return Status::OK(); + } else { + s = Status(s.code(), + strings::StrCat(s.error_message(), + ": perhaps your file is in a different " + "file format and you need to use a " + "different restore operator?")); + } + } + } + LOG(WARNING) << "Could not open " << fname << ": " << s; + delete f; + return s; +} + +TensorSliceReader::TensorSliceReader(const string& filepattern, + OpenTableFunction open_function) + : TensorSliceReader(filepattern, open_function, kLoadAllShards) {} + +TensorSliceReader::TensorSliceReader(const string& filepattern, + OpenTableFunction open_function, + int preferred_shard) + : filepattern_(filepattern), open_function_(open_function) { + VLOG(1) << "TensorSliceReader for " << filepattern; + Status s = io::GetMatchingFiles(Env::Default(), filepattern, &fnames_); + if (!s.ok()) { + status_ = errors::InvalidArgument( + "Unsuccessful TensorSliceReader constructor: " + "Failed to get matching files on ", + filepattern, ": ", s.ToString()); + return; + } + if (fnames_.empty()) { + status_ = errors::NotFound( + "Unsuccessful TensorSliceReader constructor: " + "Failed to find any matching files for ", + filepattern); + return; + } + sss_.resize(fnames_.size()); + for (size_t shard = 0; shard < fnames_.size(); ++shard) { + fname_to_index_.insert(std::make_pair(fnames_[shard], shard)); + } + if (preferred_shard == kLoadAllShards || fnames_.size() == 1 || + 
static_cast<size_t>(preferred_shard) >= fnames_.size()) { + LoadAllShards(); + } else { + VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_; + LoadShard(preferred_shard); + } +} + +void TensorSliceReader::LoadShard(int shard) const { + CHECK_LT(shard, sss_.size()); + if (sss_[shard] || !status_.ok()) { + return; // Already loaded, or invalid. + } + string value; + SavedTensorSlices sts; + const string fname = fnames_[shard]; + VLOG(1) << "Reading meta data from file " << fname << "..."; + Table* table; + Status s = open_function_(fname, &table); + if (!s.ok()) { + status_ = errors::DataLoss("Unable to open table file ", fname, ": ", + s.ToString()); + return; + } + sss_[shard].reset(table); + if (!(table->Get(kSavedTensorSlicesKey, &value) && + ParseProtoUnlimited(&sts, value))) { + status_ = errors::Internal( + "Failed to find the saved tensor slices at the beginning of the " + "checkpoint file: ", + fname); + return; + } + for (const SavedSliceMeta& ssm : sts.meta().tensor()) { + TensorShape ssm_shape(ssm.shape()); + for (const TensorSliceProto& tsp : ssm.slice()) { + TensorSlice ss_slice(tsp); + RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname, ss_slice); + } + } +} + +void TensorSliceReader::LoadAllShards() const { + VLOG(1) << "Loading all shards for " << filepattern_; + for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) { + LoadShard(i); + } + all_shards_loaded_ = true; +} + +const TensorSliceSet* TensorSliceReader::FindTensorSlice( + const string& name, const TensorSlice& slice, + std::vector<std::pair<TensorSlice, string>>* details) const { + const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name); + if (tss && !tss->QueryMeta(slice, details)) { + return nullptr; + } + return tss; +} + +TensorSliceReader::~TensorSliceReader() { gtl::STLDeleteValues(&tensors_); } + +void TensorSliceReader::RegisterTensorSlice(const string& name, + const TensorShape& shape, + DataType type, const string& tag, + const TensorSlice& slice) const { + TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name); + // Create a tensor slice set if needed + if (!tss) { + tss = new TensorSliceSet(shape, type); + tensors_.insert(std::make_pair(name, tss)); + } else { + // Check if the shapes match + TensorShape tss_shape(tss->shape()); + if (!shape.IsSameSize(tss_shape)) { + status_ = + errors::Internal("Incompatible tensor shapes detected for tensor ", + name, ": existing = ", tss_shape.DebugString(), + ", new = ", shape.DebugString()); + return; + } + if (type != tss->type()) { + status_ = + errors::Internal("Incompatible tensor types detected for tensor ", + name, ": existing = ", DataTypeString(tss->type()), + ", new = ", DataTypeString(type)); + return; + } + } + // Register the tensor slices without the actual data. 
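+  // Passing nullptr as the data pointer makes this a metadata-only entry; the
+  // actual bytes are read from the underlying table on demand in
+  // CopySliceData().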
+ Status s = tss->Register(slice, tag, nullptr); + if (!s.ok()) { + status_ = s; + } +} + +bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape, + DataType* type) const { + mutex_lock l(mu_); + const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name); + if (!tss && !all_shards_loaded_) { + VLOG(1) << "Did not find tensor in preferred shard, loading all shards: " + << name; + LoadAllShards(); + tss = gtl::FindPtrOrNull(tensors_, name); + } + if (tss) { + if (shape) { + *shape = tss->shape(); + } + if (type) { + *type = tss->type(); + } + return true; + } else { + return false; + } +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h new file mode 100644 index 0000000000..b5f26a689b --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader.h @@ -0,0 +1,157 @@ +// The utility to read checkpoints for google brain tensor ops and v3 +// checkpoints for dist_belief. +// + +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_ + +#include <unordered_map> + +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/util/saved_tensor_slice.pb.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_set.h" +#include "tensorflow/core/util/tensor_slice_util.h" + +namespace tensorflow { + +namespace checkpoint { + +// The reader reads in all the meta data about all the tensor slices. Then it +// will try to read the relevant data on-demand to produce the data for the +// slices needed. +// NOTE(yangke): another way to do this is to first load a list of the tensor +// slices needed and then just selectively read some of the meta data. That +// might optimize the loading but makes the logic a bit more complicated. We +// might want to revisit that. +// TODO(yangke): consider moving to TensorProto. +class TensorSliceReader { + public: + // Abstract interface for reading data out of a tensor slice checkpoint file + class Table { + public: + virtual ~Table(); + virtual bool Get(const string& key, string* value) = 0; + }; + typedef std::function<Status(const string&, Table**)> OpenTableFunction; + + static const int kLoadAllShards = -1; + TensorSliceReader(const string& filepattern, OpenTableFunction open_function); + TensorSliceReader(const string& filepattern, OpenTableFunction open_function, + int preferred_shard); + virtual ~TensorSliceReader(); + + // Get the filename this reader is attached to. + const string& filepattern() const { return filepattern_; } + + // Get the number of files matched. + int num_files() const { return sss_.size(); } + + // Get the status of the reader. + const Status status() const { return status_; } + + // Checks if the reader contains any slice of a tensor. In case the reader + // does contain the tensor, if "shape" is not nullptr, fill "shape" with the + // shape of the tensor; if "type" is not nullptr, fill "type" with the type + // of the tensor. 
+ bool HasTensor(const string& name, TensorShape* shape, DataType* type) const; + + // Checks if the reader contains all the data about a tensor slice, and if + // yes, copies the data of the slice to "data". The caller needs to make sure + // that "data" points to a buffer that holds enough data. + // This is a slow function since it needs to read sstables. + template <typename T> + bool CopySliceData(const string& name, const TensorSlice& slice, + T* data) const; + + // Get the tensors. + const std::unordered_map<string, TensorSliceSet*>& Tensors() const { + return tensors_; + } + + private: + friend class TensorSliceWriteTestHelper; + + void LoadShard(int shard) const; + void LoadAllShards() const; + void RegisterTensorSlice(const string& name, const TensorShape& shape, + DataType type, const string& tag, + const TensorSlice& slice) const; + + const TensorSliceSet* FindTensorSlice( + const string& name, const TensorSlice& slice, + std::vector<std::pair<TensorSlice, string>>* details) const; + + const string filepattern_; + const OpenTableFunction open_function_; + std::vector<string> fnames_; + std::unordered_map<string, int> fname_to_index_; + + // Guards the attributes below. + mutable mutex mu_; + mutable bool all_shards_loaded_ = false; + mutable std::vector<std::unique_ptr<Table>> sss_; + mutable std::unordered_map<string, TensorSliceSet*> tensors_; + mutable Status status_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceReader); +}; + +Status OpenTableTensorSliceReader(const string& fname, + TensorSliceReader::Table** table); + +template <typename T> +bool TensorSliceReader::CopySliceData(const string& name, + const TensorSlice& slice, T* data) const { + std::vector<std::pair<TensorSlice, string>> details; + const TensorSliceSet* tss; + { + mutex_lock l(mu_); + tss = FindTensorSlice(name, slice, &details); + if (!tss && !all_shards_loaded_) { + VLOG(1) << "Did not find slice in preferred shard, loading all shards." + << name << ": " << slice.DebugString(); + LoadAllShards(); + tss = FindTensorSlice(name, slice, &details); + } + if (!tss) { + // No such tensor + return false; + } + } + // We have the data -- copy it over. 
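+  // Each entry in "details" names a stored slice and the tag (file name) that
+  // holds it; we look up that slice's record in the corresponding table and
+  // copy the overlapping region into "data".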
+ string value; + for (const auto& x : details) { + const TensorSlice& slice_s = x.first; + const string& fname = x.second; + int idx = gtl::FindWithDefault(fname_to_index_, fname, -1); + CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname; + // We read a record in the corresponding sstable + const string key = EncodeTensorNameSlice(name, slice_s); + CHECK(sss_[idx]->Get(key, &value)) + << "Failed to seek to the record for tensor " << name << ", slice " + << slice_s.DebugString() << ": computed key = " << key; + SavedTensorSlices sts; + CHECK(ParseProtoUnlimited(&sts, value)) + << "Failed to parse the record for tensor " << name << ", slice " + << slice_s.DebugString() << ": computed key = " << key; + CopyDataFromTensorSliceToTensorSlice( + tss->shape(), slice_s, slice, + checkpoint::TensorProtoData<T>(sts.data().data()), data); + } + return true; +} + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_ diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc new file mode 100644 index 0000000000..af81d0115e --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader_cache.cc @@ -0,0 +1,94 @@ +#include "tensorflow/core/util/tensor_slice_reader_cache.h" + +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace checkpoint { + +TensorSliceReaderCacheWrapper::TensorSliceReaderCacheWrapper() {} +TensorSliceReaderCacheWrapper::~TensorSliceReaderCacheWrapper() { + if (cache_) { + delete cache_; + } + cache_ = nullptr; +} + +const TensorSliceReader* TensorSliceReaderCacheWrapper::GetReader( + const string& filepattern, + TensorSliceReader::OpenTableFunction open_function, + int preferred_shard) const { + mutex_lock l(mu_); + if (!cache_) { + cache_ = new TensorSliceReaderCache; + } + return cache_->GetReader(filepattern, open_function, preferred_shard); +} + +TensorSliceReaderCache::TensorSliceReaderCache() {} + +TensorSliceReaderCache::~TensorSliceReaderCache() { + for (auto pair : readers_) { + delete pair.second.second; + } +} + +const TensorSliceReader* TensorSliceReaderCache::GetReader( + const string& filepattern, + TensorSliceReader::OpenTableFunction open_function, int preferred_shard) { + mutex_lock l(mu_); + + // Get the function pointer from the open_function value. + TensorSliceReaderCache::OpenFuncType* func_ptr = + open_function.target<TensorSliceReaderCache::OpenFuncType>(); + if (!func_ptr) { + // We could not get the pointer, no caching is possible. + LOG(WARNING) << "Caching disabled because the open function is a lambda."; + return nullptr; + } + + // Wait if another thread is already trying to open the same files. + while (still_opening_.find(filepattern) != still_opening_.end()) { + cv_.wait(l); + } + + TensorSliceReader* reader = nullptr; + if (readers_.find(filepattern) == readers_.end()) { + VLOG(1) << "Creating new TensorSliceReader for " << filepattern; + still_opening_.insert(filepattern); + // Release the lock temporary as constructing TensorSliceReader is + // expensive. + mu_.unlock(); + TensorSliceReader* tmp_reader( + new TensorSliceReader(filepattern, open_function, preferred_shard)); + // Acquire the lock again. 
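+    // The mutex_lock "l" taken at the top of GetReader() still owns mu_, so it
+    // must be re-locked here before "l" goes out of scope and unlocks it.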
+ mu_.lock(); + if (tmp_reader->status().ok()) { + reader = tmp_reader; + readers_[filepattern] = make_pair(*func_ptr, reader); + } else { + delete tmp_reader; + } + CHECK_EQ(1, still_opening_.erase(filepattern)); + VLOG(1) << "Cached TensorSliceReader for " << filepattern << ": " << reader; + } else { + auto cached_val = readers_[filepattern]; + if (cached_val.first == *func_ptr) { + reader = cached_val.second; + VLOG(1) << "Using cached TensorSliceReader for " << filepattern << ": " + << reader; + } else { + LOG(WARNING) << "Caching disabled because the checkpoint file " + << "is being opened with two different open functions: " + << filepattern; + } + } + + cv_.notify_all(); + return reader; +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_reader_cache.h b/tensorflow/core/util/tensor_slice_reader_cache.h new file mode 100644 index 0000000000..eaeeeec83f --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader_cache.h @@ -0,0 +1,73 @@ +// The utility to read checkpoints for google brain tensor ops and v3 +// checkpoints for dist_belief. +// + +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_ + +#include <unordered_map> + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/util/tensor_slice_reader.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +namespace checkpoint { + +class TensorSliceReaderCache; + +// Wrapper to a lazily allocated TensorSliceReaderCache. +class TensorSliceReaderCacheWrapper { + public: + TensorSliceReaderCacheWrapper(); + ~TensorSliceReaderCacheWrapper(); + + // Same as TensorSliceReaderCache::GetReader(). + const TensorSliceReader* GetReader( + const string& filepattern, + TensorSliceReader::OpenTableFunction open_function, + int preferred_shard) const; + + private: + mutable mutex mu_; + mutable TensorSliceReaderCache* cache_ = nullptr; +}; + +// A cache of TensorSliceReaders. +class TensorSliceReaderCache { + public: + TensorSliceReaderCache(); + ~TensorSliceReaderCache(); + + // Returns the TensorSliceReader corresponding to 'filepattern' and the + // open_function. May return nullptr if we can not create a new + // TensorSliceReader for the filepattern/open_function combination. + const TensorSliceReader* GetReader( + const string& filepattern, + TensorSliceReader::OpenTableFunction open_function, int preferred_shard); + + private: + // Need to use a regular function type in the key map as std::function does + // not support ==. + typedef Status (*OpenFuncType)(const string&, TensorSliceReader::Table**); + + // Protects attributes below. + mutex mu_; + + // Maps of opened readers. + std::unordered_map<string, std::pair<OpenFuncType, TensorSliceReader*>> + readers_; + + // Set of keys that a previous GetReader() call is still trying to populate. + std::set<string> still_opening_; + + // Condition variable to notify when a reader has been created. 
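+  // GetReader() waits on this while another call is still opening the same
+  // filepattern, and notifies all waiters before returning.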
+ condition_variable cv_; +}; + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_ diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc new file mode 100644 index 0000000000..e14b920003 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_reader_test.cc @@ -0,0 +1,395 @@ +#include "tensorflow/core/util/tensor_slice_reader.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_writer.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +// A simple test where we write a few tensor slices with a number of tensor +// slice writers and then read them back from a tensor slice reader. +// +// We have a 2-d tensor of shape 4 X 5 that looks like this: +// +// 0 1 2 3 4 +// 5 6 7 8 9 +// 10 11 12 13 14 +// 15 16 17 18 19 +// +// We assume this is a row-major matrix. + +void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function) { + const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const float data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const float data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . 
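+  // (Because of this hole, any slice query that touches element 13 or 14
+  // cannot be satisfied; this is exercised by the last query below.)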
+ + // Now we need to read the tensor slices + const string filepattern = strings::StrCat(fname_base, "_*"); + TensorSliceReader reader(filepattern, open_function); + EXPECT_OK(reader.status()); + EXPECT_EQ(2, reader.num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DT_FLOAT, type); + EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr)); + } + + // Now we query some slices + // + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + float results[10]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + float expected[] = {5, 6, 7, 8, 9}; + float results[5]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + float results[6]; + EXPECT_FALSE(reader.CopySliceData("test", s, results)); + } +} + +TEST(TensorSliceReaderTest, SimpleFloat) { + SimpleFloatHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader); +} + +template <typename T, typename U> +void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function, + const string& checkpoint_file) { + const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const T data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const T data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const T data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . 
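+  // Note: the reads below use U-typed buffers while the writes above used T.
+  // U is the SAVED_TYPE from the TEST_SIMPLE_INT macro further down, so the
+  // narrower integer types (int8/int16/uint8) are presumably read back via
+  // their wider on-disk representation (int32).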
+ + // Now we need to read the tensor slices + const string filepattern = strings::StrCat(fname_base, "_*"); + TensorSliceReader reader(filepattern, open_function); + EXPECT_OK(reader.status()); + EXPECT_EQ(2, reader.num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader.HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DataTypeToEnum<T>::v(), type); + EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr)); + } + + // Now we query some slices + // + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + T expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + U results[10]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + T expected[] = {5, 6, 7, 8, 9}; + U results[5]; + EXPECT_TRUE(reader.CopySliceData("test", s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + U results[6]; + EXPECT_FALSE(reader.CopySliceData("test", s, results)); + } +} + +#define TEST_SIMPLE_INT(TYPE, SAVED_TYPE) \ + TEST(TensorSliceReaderTest, Simple##TYPE) { \ + SimpleIntXHelper<TYPE, SAVED_TYPE>(CreateTableTensorSliceBuilder, \ + OpenTableTensorSliceReader, \ + #TYPE "_checkpoint"); \ + } + +TEST_SIMPLE_INT(int32, int32) +TEST_SIMPLE_INT(int64, int64) +TEST_SIMPLE_INT(int16, int32) +TEST_SIMPLE_INT(int8, int32) +TEST_SIMPLE_INT(uint8, int32) + +void CachedTensorSliceReaderTesterHelper( + TensorSliceWriter::CreateBuilderFunction create_function, + TensorSliceReader::OpenTableFunction open_function) { + const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); + + TensorShape shape({4, 5}); + + // File #0 contains a slice that is the top two rows: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + const string fname = strings::StrCat(fname_base, "_0"); + TensorSliceWriter writer(fname, create_function); + const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + TF_CHECK_OK(writer.Finish()); + } + + // File #1 contains two slices: + // + // slice #0 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + // + // slice #1 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + { + const string fname = strings::StrCat(fname_base, "_1"); + TensorSliceWriter writer(fname, create_function); + // slice #0 + { + const float data[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + // slice #1 + { + const float data[] = {18, 19}; + TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + TF_CHECK_OK(writer.Finish()); + } + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . 
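+  // Unlike the helpers above, this variant obtains its reader from a
+  // TensorSliceReaderCache, so repeated GetReader() calls with the same
+  // filepattern and open function should return the same pointer.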
+ + // Now we need to read the tensor slices + TensorSliceReaderCache cache; + const string filepattern = strings::StrCat(fname_base, "_*"); + const TensorSliceReader* reader = cache.GetReader( + filepattern, open_function, TensorSliceReader::kLoadAllShards); + EXPECT_TRUE(reader != nullptr); + EXPECT_EQ(2, reader->num_files()); + + // We query some of the tensors + { + TensorShape shape; + DataType type; + EXPECT_TRUE(reader->HasTensor("test", &shape, &type)); + EXPECT_EQ( + "dim { size: 4 } " + "dim { size: 5 }", + shape.DebugString()); + EXPECT_EQ(DT_FLOAT, type); + EXPECT_FALSE(reader->HasTensor("don't exist", nullptr, nullptr)); + } + + // Make sure the reader is cached. + const TensorSliceReader* reader2 = cache.GetReader( + filepattern, open_function, TensorSliceReader::kLoadAllShards); + EXPECT_EQ(reader, reader2); + + reader = cache.GetReader("file_does_not_exist", open_function, + TensorSliceReader::kLoadAllShards); + EXPECT_TRUE(reader == nullptr); +} + +TEST(CachedTensorSliceReaderTest, SimpleFloat) { + CachedTensorSliceReaderTesterHelper(CreateTableTensorSliceBuilder, + OpenTableTensorSliceReader); +} + +} // namespace + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc new file mode 100644 index 0000000000..765686f189 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_set.cc @@ -0,0 +1,148 @@ +#include "tensorflow/core/util/tensor_slice_set.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/tensor_slice_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +namespace checkpoint { + +TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type) + : shape_(shape), type_(type) {} + +TensorSliceSet::~TensorSliceSet() {} + +Status TensorSliceSet::Register(const TensorSlice& slice, + const string& tag, const float* data) { + TensorShape result_shape; + TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape)); + string str = slice.DebugString(); + // We check if there is any intersection between this slice and any of the + // registered slices. + for (const auto x : slices_) { + if (slice.Overlaps(x.second.slice)) { + return errors::Internal("Overlapping slices: existing slice = ", x.first, + ", new slice = ", str); + } + } + // No overlap: we can now insert the slice + TensorSliceSet::SliceInfo info = {slice, tag, data, + result_shape.num_elements()}; + slices_.insert(std::make_pair(str, info)); + return Status::OK(); +} + +// TODO(yangke): merge Query() with QueryMeta() +bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const { + Status s; + string str = slice.DebugString(); + // First we check if there is an exactly match (this is the dominant case). + const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str); + if (info) { + if (data) { + std::copy_n(info->data, info->num_floats, data); + } + return true; + } else { + // We didn't find any exact match but there is still a posibility that + // mutliple existing slices can be patched together to output the slice. + // We figure this out by computing the intersection of each of the existing + // slices with the query slice, and check if the union of all these + // intersections cover the entire slice. We rely on the fact that the + // existing slices don't have any intersection among themselves. 
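+    // For example, with the 4 x 5 tensor from the unit tests, querying
+    // "1,2:0,3" (6 elements) succeeds: it overlaps "0,2:-" in 3 elements and
+    // "2,2:0,3" in another 3, and 3 + 3 covers all 6 elements of the query.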
+ TensorShape target_shape; + Status s; + s = slice.SliceTensorShape(shape_, &target_shape); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + int64 total_size = target_shape.num_elements(); + + int64 overlap_size = 0; + TensorSlice intersection; + TensorShape inter_shape; + for (const auto x : slices_) { + if (slice.Intersect(x.second.slice, &intersection)) { + s = intersection.SliceTensorShape(shape_, &inter_shape); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + overlap_size += inter_shape.num_elements(); + } + } + if (total_size == overlap_size) { + // We have it! + // Now we need to copy the data to "data" + if (data) { + for (const auto x : slices_) { + CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice, + x.second.data, data); + } + } + return true; + } else { + // We don't have all the data for the asked tensor slice + return false; + } + } +} + +bool TensorSliceSet::QueryMeta( + const TensorSlice& slice, + std::vector<std::pair<TensorSlice, string>>* results) const { + results->clear(); + Status s; + string str = slice.DebugString(); + // First we check if there is an exactly match (this is the dominant case). + const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str); + if (info) { + results->emplace_back(std::make_pair(info->slice, info->tag)); + return true; + } else { + // We didn't find any exact match but there is still a posibility that + // multiple existing slices can be patched together to output the slice. + // We figure this out by computing the intersection of each of the existing + // slices with the query slice, and check if the union of all these + // intersections cover the entire slice. We rely on the fact that the + // existing slices don't have any intersection among themselves. + TensorShape target_shape; + Status s; + s = slice.SliceTensorShape(shape_, &target_shape); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + int64 total_size = target_shape.num_elements(); + + int64 overlap_size = 0; + TensorSlice intersection; + TensorShape inter_shape; + for (const auto x : slices_) { + if (slice.Intersect(x.second.slice, &intersection)) { + s = intersection.SliceTensorShape(shape_, &inter_shape); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + overlap_size += inter_shape.num_elements(); + results->emplace_back(std::make_pair(x.second.slice, x.second.tag)); + } + } + if (total_size == overlap_size) { + // We have it! + return true; + } else { + // We don't have all the data for the asked tensor slice + results->clear(); + return false; + } + } +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_set.h b/tensorflow/core/util/tensor_slice_set.h new file mode 100644 index 0000000000..f3f7ac0e76 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_set.h @@ -0,0 +1,73 @@ +// A class to manage slices of a tensor. You can "register" set of slices for a +// tensor and then "query" if we have data for a given slice. + +// TODO(yangke): consider moving it to a more private place so that we don't +// need to expose the API. 
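+//
+// A minimal usage sketch, mirroring the unit tests (error handling omitted;
+// "data" is assumed to point at 10 row-major floats for the registered slice):
+//
+//   TensorSliceSet tss(TensorShape({4, 5}), DT_FLOAT);
+//   TF_CHECK_OK(tss.Register(TensorSlice::ParseOrDie("0,2:-"), "file_0", data));
+//   float out[10];
+//   if (tss.Query(TensorSlice::ParseOrDie("0,2:-"), out)) { /* "out" is filled */ }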
+ +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_ + +#include <string> // for string +#include <unordered_map> + +#include "tensorflow/core/platform/port.h" // for int64 +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/core/stringpiece.h" // for StringPiece +#include "tensorflow/core/public/status.h" // for Status + +namespace tensorflow { + +namespace checkpoint { + +class TensorSliceSet { + public: + TensorSliceSet(const TensorShape& shape, DataType type); + virtual ~TensorSliceSet(); + + const TensorShape& shape() const { return shape_; } + const DataType type() const { return type_; } + + // Register a new slice for the tensor. The "tag" is an arbitrary string + // associated with the slice (in one application it denotes the name of the + // file that contains the slice); the "data" points to the data of the tensor + // slice (it can be a nullptr). + // We don't take the ownership of "data" and the caller needs to make sure + // the data is always available during the life time of the tensor slice set + // if it is not nullptr. + Status Register(const TensorSlice& slice, const string& tag, + const float* data); + + // Query about a new slice: checks if we have data for "slice" and if we have + // the data and "data" is not nullptr, fill "data" with the slice data. The + // caller needs to make sure "data" point to a large eough buffer. + // TODO(yangke): avoid unnecessary copying by using a core::RefCounted + // pointer. + bool Query(const TensorSlice& slice, float* data) const; + + // Alternative way of querying about a new slice: instead of copying the + // data, it returns a list of meta data about the stored slices that will + // supply data for the slice. + bool QueryMeta( + const TensorSlice& slice, + std::vector<std::pair<tensorflow::TensorSlice, string>>* results) const; + + private: + const TensorShape shape_; + const DataType type_; + struct SliceInfo { + TensorSlice slice; + const string tag; + const float* data; + int64 num_floats; + }; + // We maintain a mapping from the slice string to the slice information. + std::unordered_map<string, SliceInfo> slices_; +}; + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_ diff --git a/tensorflow/core/util/tensor_slice_set_test.cc b/tensorflow/core/util/tensor_slice_set_test.cc new file mode 100644 index 0000000000..fb2f46f34c --- /dev/null +++ b/tensorflow/core/util/tensor_slice_set_test.cc @@ -0,0 +1,227 @@ +#include "tensorflow/core/util/tensor_slice_set.h" + +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +// A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this: +// +// 0 1 2 3 4 +// 5 6 7 8 9 +// 10 11 12 13 14 +// 15 16 17 18 19 +// +// We assume this is a row-major matrix. +// +// We store the tensor in a couple of slices and verify that we can recover all +// of them. +TEST(TensorSliceSetTest, QueryTwoD) { + TensorShape shape({4, 5}); + + TensorSliceSet tss(shape, DT_FLOAT); + // We store a few slices. + + // Slice #1 is the top two rows: + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . 
+ const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(tss.Register(slice_1, "", src_1)); + + // Slice #2 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + const float src_2[] = {10, 11, 12, 15, 16, 17}; + TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(tss.Register(slice_2, "", src_2)); + + // Slice #3 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + const float src_3[] = {18, 19}; + TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(tss.Register(slice_3, "", src_3)); + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we query some of the slices + + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + float results[10]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + float expected[] = {5, 6, 7, 8, 9}; + float results[5]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #3 is a more complicated match: it needs the combination of a couple + // of slices + // . . . . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3"); + float expected[] = {5, 6, 7, 10, 11, 12}; + float results[6]; + EXPECT_TRUE(tss.Query(s, results)); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(expected[i], results[i]); + } + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + float results[6]; + EXPECT_FALSE(tss.Query(s, results)); + } +} + +// Testing the meta version of the tensor slice set. +TEST(TensorSliceSetTest, QueryMetaTwoD) { + TensorShape shape({4, 5}); + + TensorSliceSet tss(shape, DT_INT32); + // We store a few slices. + + // Slice #1 is the top two rows: + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . + TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-"); + TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr)); + + // Slice #2 is the bottom left corner + // . . . . . + // . . . . . + // 10 11 12 . . + // 15 16 17 . . + TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3"); + TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr)); + + // Slice #3 is the bottom right corner + // . . . . . + // . . . . . + // . . . . . + // . . . 18 19 + TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2"); + TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr)); + + // Notice that we leave a hole in the tensor + // . . . . . + // . . . . . + // . . . (13) (14) + // . . . . . + + // Now we query some of the slices + + // Slice #1 is an exact match + // 0 1 2 3 4 + // 5 6 7 8 9 + // . . . . . + // . . . . . 
+ // We just need slice_1 for this + { + TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); + std::vector<std::pair<TensorSlice, string>> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(1, results.size()); + EXPECT_EQ("0,2:-", results[0].first.DebugString()); + EXPECT_EQ("slice_1", results[0].second); + } + + // Slice #2 is a subset match + // . . . . . + // 5 6 7 8 9 + // . . . . . + // . . . . . + // We just need slice_1 for this + { + TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); + std::vector<std::pair<TensorSlice, string>> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(1, results.size()); + EXPECT_EQ("0,2:-", results[0].first.DebugString()); + EXPECT_EQ("slice_1", results[0].second); + } + + // Slice #3 is a more complicated match: it needs the combination of a couple + // of slices + // . . . . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + // We need both slice_1 and slice_2 for this. + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3"); + std::vector<std::pair<TensorSlice, string>> results; + EXPECT_TRUE(tss.QueryMeta(s, &results)); + EXPECT_EQ(2, results.size()); + EXPECT_EQ("2,2:0,3", results[0].first.DebugString()); + EXPECT_EQ("slice_2", results[0].second); + EXPECT_EQ("0,2:-", results[1].first.DebugString()); + EXPECT_EQ("slice_1", results[1].second); + } + + // Slice #4 includes the hole and so there is no match + // . . . . . + // . . 7 8 9 + // . . 12 13 14 + // . . . . . + { + TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); + std::vector<std::pair<TensorSlice, string>> results; + EXPECT_FALSE(tss.QueryMeta(s, &results)); + EXPECT_EQ(0, results.size()); + } +} + +} // namespace + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_util.h b/tensorflow/core/util/tensor_slice_util.h new file mode 100644 index 0000000000..5422c3bef3 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_util.h @@ -0,0 +1,88 @@ +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ + +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +// Some hackery to invoke eigen tensor to copy over tensor slices with variable +// dimension tensors. +// TODO(yangke): get rid of that once the variable dimension tensor support is +// in. +static const int kTensorSliceMaxRank = 8; + +// Create a tensor map with the given shape: we support up to 8 dimensions. If +// the shape has less than 8 dimensions, we pad the remaining dimension with 1. +template <typename T> +Eigen::TensorMap<Eigen::Tensor<T, kTensorSliceMaxRank, Eigen::RowMajor>> +GetEigenTensorMapFromTensorShape(const TensorShape& shape, T* data) { + Eigen::DSizes<Eigen::DenseIndex, kTensorSliceMaxRank> dsizes = + shape.AsEigenDSizesWithPadding<kTensorSliceMaxRank>(); + Eigen::TensorMap<Eigen::Tensor<T, kTensorSliceMaxRank, Eigen::RowMajor>> eig( + data, dsizes); + return eig; +} + +// Given a tensor described by "shape", two slices "slice_s" and "slice_d", +// and two pointers "ptr_s" and "ptr_d", where "ptr_s" points to a chunk of +// memory that stores the data for "slice_s" and "ptr_d" points to a chunk of +// memory that stores the data for "slice_d". This function copies the data +// that belongs to the intersection of the two slices from slice_s to +// slice_d. Uses Tensor cast<DstT>() to convert from SrcT to DstT. 
Returns true +// iff the two slices share any intersection (and thus some data is copied). +// TODO(yangke): figure out if we can make it private. +template <typename SrcT, typename DstT> +static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape, + const TensorSlice& slice_s, + const TensorSlice& slice_d, + const SrcT* ptr_s, + DstT* ptr_d) { + CHECK_LE(shape.dims(), kTensorSliceMaxRank) << "Only tensors of size up to " + << kTensorSliceMaxRank + << " are supported"; + // We need to compute the intersection of the two slices. + TensorSlice inter; + if (!slice_s.Intersect(slice_d, &inter)) { + // There is no intersection: returns false. + return false; + } else { + // We need to compute the applied shapes after applying slice_s and + // slice_d. + TensorShape shp_s, shp_d; + Status s; + s = slice_s.SliceTensorShape(shape, &shp_s); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + s = slice_d.SliceTensorShape(shape, &shp_d); + if (!s.ok()) { + LOG(WARNING) << s; + return false; + } + + // We need to compute the relative slice of "inter" w.r.t. both slice_s and + // slice_d. + TensorSlice rel_s, rel_d; + slice_s.ComputeRelative(inter, &rel_s); + slice_d.ComputeRelative(inter, &rel_d); + + // Get the eigen tensor maps to the data. + auto t_s = GetEigenTensorMapFromTensorShape(shp_s, ptr_s); + auto t_d = GetEigenTensorMapFromTensorShape(shp_d, ptr_d); + + Eigen::DSizes<Eigen::DenseIndex, kTensorSliceMaxRank> s_start, s_len, + d_start, d_len; + + rel_s.FillIndicesAndSizes<kTensorSliceMaxRank>(shp_s, &s_start, &s_len); + rel_d.FillIndicesAndSizes<kTensorSliceMaxRank>(shp_d, &d_start, &d_len); + t_d.slice(d_start, d_len) = t_s.slice(s_start, s_len).template cast<DstT>(); + return true; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ diff --git a/tensorflow/core/util/tensor_slice_util_test.cc b/tensorflow/core/util/tensor_slice_util_test.cc new file mode 100644 index 0000000000..348b0c884e --- /dev/null +++ b/tensorflow/core/util/tensor_slice_util_test.cc @@ -0,0 +1,91 @@ +#include "tensorflow/core/util/tensor_slice_util.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +// Testing copying data from one tensor slice to another tensor slice +TEST(TensorSliceUtilTest, CopyTensorSliceToTensorSlice) { + // We map out a 2-d tensor of size 4 X 5 and we want the final results look + // like this: + // + // 0 1 2 3 4 + // 5 6 7 8 9 + // 10 11 12 13 14 + // 15 16 17 18 19 + // + // We assume this is a row-major matrix + // + TensorShape shape({4, 5}); + + // We will try to do a couple of slice to slice copies. + + // Case 1: simple identity copy + // The slice is the "interior" of the matrix + // . . . . . + // . 6 7 8 . + // , 11 12 13 . + // . . . . . 
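+  // (Here slice_s == slice_d, so the intersection is the whole slice and the
+  // copy is element-for-element.)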
+ { + TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,3"); + const float ptr_s[] = {6, 7, 8, 11, 12, 13}; + float ptr_d[6]; + for (int i = 0; i < 6; ++i) { + ptr_d[i] = 0; + } + EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(ptr_s[i], ptr_d[i]); + } + } + + // Case 2: no intersection + { + TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("3,1:2,3"); + const float ptr_s[] = {6, 7, 8, 11, 12, 13}; + float ptr_d[6]; + EXPECT_FALSE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + } + + // Case 3: a trickier case + // The source slice is on the upper left corner: + // 0 1 2 . . + // 5 6 7 . . + // 10 11 12 . . + // . . . . . + // + // The destination slice is the right part of the middle stripe: + // . . . . . + // . X X X X + // . X X X X + // . . . . . + // + // So we expect to copy over the 2X2 block: + // . . . . . + // . 6 7 . . + // . 11 12 . . + // . . . . . + { + TensorSlice slice_s = TensorSlice::ParseOrDie("0,3:0,3"); + TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,4"); + const float ptr_s[] = {0, 1, 2, 5, 6, 7, 10, 11, 12}; + float ptr_d[8]; + for (int i = 0; i < 8; ++i) { + ptr_d[i] = 0; + } + EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d, + ptr_s, ptr_d)); + const float expected[] = {6, 7, 0, 0, 11, 12, 0, 0}; + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(expected[i], ptr_d[i]); + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc new file mode 100644 index 0000000000..bb2fd96c05 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_writer.cc @@ -0,0 +1,110 @@ +#include "tensorflow/core/util/tensor_slice_writer.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" + +namespace tensorflow { + +namespace checkpoint { + +namespace { + +class TableBuilder : public TensorSliceWriter::Builder { + public: + TableBuilder(const string& name, WritableFile* f) + : name_(name), + file_(f), + builder_(new table::TableBuilder(table::Options(), f)) {} + void Add(StringPiece key, StringPiece val) override { + builder_->Add(key, val); + } + Status Finish(int64* file_size) override { + *file_size = -1; + Status s = builder_->Finish(); + if (s.ok()) { + s = file_->Close(); + if (s.ok()) { + *file_size = builder_->FileSize(); + } + } + if (!s.ok()) { + s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ", + s.ToString()); + } + builder_.reset(); + file_.reset(); + return s; + } + + private: + string name_; + std::unique_ptr<WritableFile> file_; + std::unique_ptr<table::TableBuilder> builder_; +}; +} // anonymous namespace + +Status CreateTableTensorSliceBuilder( + const string& name, TensorSliceWriter::Builder** builder) { + *builder = nullptr; + WritableFile* f; + Status s = Env::Default()->NewWritableFile(name, &f); + if (s.ok()) { + *builder = new TableBuilder(name, f); + return Status::OK(); + } else { + return s; + } +} + +TensorSliceWriter::TensorSliceWriter(const string& filename, + CreateBuilderFunction 
create_builder) + : filename_(filename), + create_builder_(create_builder), + tmpname_(strings::StrCat(filename, ".tempstate", random::New64())), + slices_(0) {} + +Status TensorSliceWriter::Finish() { + Builder* b; + Status s = create_builder_(tmpname_, &b); + if (!s.ok()) { + delete b; + return s; + } + std::unique_ptr<Builder> builder(b); + + // We save the saved tensor slice metadata as the first element. + string meta; + sts_.AppendToString(&meta); + builder->Add(kSavedTensorSlicesKey, meta); + + // Go through all the data and add them + for (const auto& x : data_) { + builder->Add(x.first, x.second); + } + + int64 file_size; + s = builder->Finish(&file_size); + // We need to rename the file to the proper name + if (s.ok()) { + s = Env::Default()->RenameFile(tmpname_, filename_); + if (s.ok()) { + VLOG(1) << "Written " << slices_ << " slices for " + << sts_.meta().tensor_size() << " tensors (" << file_size + << " bytes) to " << filename_; + } else { + LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_; + } + } else { + Env::Default()->DeleteFile(tmpname_); + } + return s; +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h new file mode 100644 index 0000000000..cce3880cb3 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_writer.h @@ -0,0 +1,149 @@ +// The utility to write checkpoints for google brain tensor ops and v3 +// checkpoints for dist_belief. +// + +#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_ +#define TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_ + +#include <unordered_map> + +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/util/saved_tensor_slice.pb.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" + +namespace tensorflow { + +namespace checkpoint { + +class TensorSliceWriter { + public: + // Abstract interface that TensorSliceWriter uses for building + class Builder { + public: + virtual ~Builder() {} + virtual void Add(StringPiece key, StringPiece value) = 0; + virtual Status Finish(int64* file_size) = 0; + }; + typedef std::function<Status(const string&, Builder**)> + CreateBuilderFunction; + + TensorSliceWriter(const string& filename, + CreateBuilderFunction create_builder); + virtual ~TensorSliceWriter() {} + // Adds a slice. We support float and int32 for now. + // TODO(yangke): add more supports + template <typename T> + Status Add(const string& name, const TensorShape& shape, + const TensorSlice& slice, const T* data); + Status Finish(); + + private: + // Allocate "num_elements" elements in "ss" and save the data in "data" + // there. + template <typename T> + static void SaveData(const T* data, int num_elements, SavedSlice* ss); + + const string filename_; + const CreateBuilderFunction create_builder_; + const string tmpname_; + + // A mapping from the tensor names to their index in meta_.saved_slice_meta() + std::unordered_map<string, int> name_to_index_; + // The metadata that holds all the saved tensor slices. 
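+  // Only the meta() part is filled in here; per-slice data is serialized into
+  // "data_" below, and Finish() writes the metadata record first followed by
+  // the data records.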
+ SavedTensorSlices sts_; + // The data to be written to the builder + std::map<string, string> data_; + // Total number of slices written + int slices_; + TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceWriter); +}; + +template <typename T> +Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, + const TensorSlice& slice, const T* data) { + // The tensor and the slice have to be compatible + if (shape.dims() != slice.dims()) { + return errors::Internal("Incompatible tensor shape and slice: ", "shape = ", + shape.DebugString(), ", slice = ", + slice.DebugString()); + } + DataType dt = DataTypeToEnum<T>::value; + // We need to add an entry for "name" if there isn't an entry already. + int index = gtl::FindWithDefault(name_to_index_, name, -1); + if (index >= 0) { + // The same tensor has been registered -- we verify that the shapes and the + // type agree. + const SavedSliceMeta& ssm = sts_.meta().tensor(index); + CHECK_EQ(name, ssm.name()) << ssm.ShortDebugString(); + TensorShape ssm_shape(ssm.shape()); + if (!shape.IsSameSize(ssm_shape)) { + return errors::Internal("Mismatching shapes: existing tensor = ", + ssm_shape.DebugString(), ", trying to add name ", + name, ", shape = ", shape.DebugString()); + } + if (dt != ssm.type()) { + return errors::Internal( + "Mismatching types: existing type = ", DataTypeString(ssm.type()), + ", trying to add name ", name, ", type = ", DataTypeString(dt)); + } + } else { + // Insert the new tensor name with the shape information + index = sts_.meta().tensor_size(); + name_to_index_.insert(std::make_pair(name, index)); + SavedSliceMeta* ssm = sts_.mutable_meta()->add_tensor(); + ssm->set_name(name); + shape.AsProto(ssm->mutable_shape()); + ssm->set_type(dt); + } + // Now we need to add the slice info the list of slices. + SavedSliceMeta* ssm = sts_.mutable_meta()->mutable_tensor(index); + slice.AsProto(ssm->add_slice()); + + // Now we need to add the real data. + { + SavedTensorSlices sts; + SavedSlice* ss = sts.mutable_data(); + ss->set_name(name); + slice.AsProto(ss->mutable_slice()); + TensorShape saved_shape(ssm->shape()); + TensorShape sliced_shape; + TF_RETURN_IF_ERROR(slice.SliceTensorShape(saved_shape, &sliced_shape)); + SaveData(data, sliced_shape.num_elements(), ss); + string key = EncodeTensorNameSlice(name, slice); + // TODO(yangke): consider doing a two-pass thing where the first pass just + // list the tensor slices we want to save and then another pass to actually + // set the data. Need to figure out if the interface works well. + std::pair<string, string> key_value(key, ""); + sts.AppendToString(&key_value.second); + data_.insert(key_value); + } + ++slices_; + return Status::OK(); +} + +template <typename T> +void TensorSliceWriter::SaveData(const T* data, int num_elements, + SavedSlice* ss) { + Fill(data, num_elements, ss->mutable_data()); +} + +// Create a table builder that will write to "filename" in +// tensorflow::io::Table format. If successful, return OK +// and set "*builder" to the allocated builder. Otherwise, return a +// non-OK status. 
+Status CreateTableTensorSliceBuilder(const string& filename, + TensorSliceWriter::Builder** builder); + +} // namespace checkpoint + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_ diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc new file mode 100644 index 0000000000..ca3dffe422 --- /dev/null +++ b/tensorflow/core/util/tensor_slice_writer_test.cc @@ -0,0 +1,248 @@ +#include "tensorflow/core/util/tensor_slice_writer.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/tensor_slice_reader.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/test.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +namespace checkpoint { + +class TensorSliceWriteTestHelper { + public: + static void CheckEntries(const string& fname); + static void GetData(TensorSliceReader::Table* table, const string& name, + const TensorSlice& slice, SavedSlice* ss); +}; + +namespace { + +// Testing that an array is what is expected +void ExpectIdenticalFloatArrays(const float* expected, int size, + const float* actual) { + // TODO(yangke): copy some of the Dump* functions over + // LOG(INFO) << "Expected = " << DumpFloatArray(expected, size); + // LOG(INFO) << "Actual = " << DumpFloatArray(actual, size); + for (int i = 0; i < size; ++i) { + EXPECT_NEAR(expected[i], actual[i], 1e-6); + } +} + +template <typename T, typename U> +void ExpectIdenticalIntArrays(const T* expected, int size, const U* actual) { + for (int i = 0; i < size; ++i) { + EXPECT_EQ(expected[i], static_cast<T>(actual[i])); + } +} + +// Nifty routine to get the size of an array +template <typename T, unsigned SIZE> +inline size_t ArraySize(const T(&v)[SIZE]) { + return SIZE; +} + +// A simple test on writing a few tensor slices +// TODO(yangke): refactor into smaller tests: will do as we add more stuff to +// the writer. 
+TEST(TensorSliceWriteTest, SimpleWrite) { + const string filename = io::JoinPath(testing::TmpDir(), "checkpoint"); + + TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder); + + // Add some int32 tensor slices + { + TensorShape shape({5, 10}); + TensorSlice slice = TensorSlice::ParseOrDie("-:0,1"); + const int32 data[] = {0, 1, 2, 3, 4}; + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + + // Two slices share the same tensor name + { + TensorShape shape({5, 10}); + TensorSlice slice = TensorSlice::ParseOrDie("-:3,1"); + const int32 data[] = {10, 11, 12, 13, 14}; + TF_CHECK_OK(writer.Add("test", shape, slice, data)); + } + + // Another slice from a different float tensor -- it has a different name and + // should be inserted in front of the previous tensor + { + TensorShape shape({3, 2}); + TensorSlice slice = TensorSlice::ParseOrDie("-:-"); + const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3}; + TF_CHECK_OK(writer.Add("AA", shape, slice, data)); + } + + // A slice with int64 data + { + TensorShape shape({5, 10}); + TensorSlice slice = TensorSlice::ParseOrDie("-:3,1"); + const int64 data[] = {10, 11, 12, 13, 14}; + TF_CHECK_OK(writer.Add("int64", shape, slice, data)); + } + + // A slice with int16 data + { + TensorShape shape({5, 10}); + TensorSlice slice = TensorSlice::ParseOrDie("-:3,1"); + const int16 data[] = {10, 11, 12, 13, 14}; + TF_CHECK_OK(writer.Add("int16", shape, slice, data)); + } + + TF_CHECK_OK(writer.Finish()); + + // Now we examine the checkpoint file manually. + TensorSliceWriteTestHelper::CheckEntries(filename); +} + +} // namespace + +void TensorSliceWriteTestHelper::GetData(TensorSliceReader::Table* table, + const string& name, + const TensorSlice& slice, + SavedSlice* ss) { + string key = EncodeTensorNameSlice(name, slice); + string value; + EXPECT_TRUE(table->Get(key, &value)); + SavedTensorSlices sts; + EXPECT_TRUE(ParseProtoUnlimited(&sts, value)); + EXPECT_FALSE(sts.has_meta()); + *ss = sts.data(); + EXPECT_EQ(name, ss->name()); + TensorSlice slice2(ss->slice()); + EXPECT_EQ(slice.DebugString(), slice2.DebugString()); +} + +void TensorSliceWriteTestHelper::CheckEntries(const string& fname) { + TensorSliceReader::Table* tptr; + TF_CHECK_OK(OpenTableTensorSliceReader(fname, &tptr)); + std::unique_ptr<TensorSliceReader::Table> table(tptr); + CHECK_NOTNULL(table.get()); + + // We expect a block of SavedTensorSlices + string value; + ASSERT_TRUE(table->Get(kSavedTensorSlicesKey, &value)); + { + SavedTensorSlices sts; + EXPECT_TRUE(ParseProtoUnlimited(&sts, value)); + // We also expect two entries for the tensors + EXPECT_TRUE(sts.has_meta()); + EXPECT_EQ(4, sts.meta().tensor_size()); + // We don't expect any data in the first block. + EXPECT_FALSE(sts.has_data()); + // The two tensors should be stored in the same order as they are first + // created. 
+ { + // The two slices of the "test" tensor + const SavedSliceMeta& ssm = sts.meta().tensor(0); + EXPECT_EQ("test", ssm.name()); + EXPECT_EQ( + "dim { size: 5 } " + "dim { size: 10 }", + ssm.shape().ShortDebugString()); + EXPECT_EQ(DT_INT32, ssm.type()); + EXPECT_EQ(2, ssm.slice_size()); + TensorSlice s0(ssm.slice(0)); + TensorSlice s1(ssm.slice(1)); + EXPECT_EQ("-:0,1", s0.DebugString()); + EXPECT_EQ("-:3,1", s1.DebugString()); + } + { + // The "AA" tensor + const SavedSliceMeta& ssm = sts.meta().tensor(1); + EXPECT_EQ("AA", ssm.name()); + EXPECT_EQ( + "dim { size: 3 } " + "dim { size: 2 }", + ssm.shape().ShortDebugString()); + EXPECT_EQ(DT_FLOAT, ssm.type()); + EXPECT_EQ(1, ssm.slice_size()); + TensorSlice s0(ssm.slice(0)); + EXPECT_EQ("-:-", s0.DebugString()); + } + { + // The "int64" tensor + const SavedSliceMeta& ssm = sts.meta().tensor(2); + EXPECT_EQ("int64", ssm.name()); + EXPECT_EQ( + "dim { size: 5 } " + "dim { size: 10 }", + ssm.shape().ShortDebugString()); + EXPECT_EQ(DT_INT64, ssm.type()); + EXPECT_EQ(1, ssm.slice_size()); + TensorSlice s0(ssm.slice(0)); + EXPECT_EQ("-:3,1", s0.DebugString()); + } + { + // The "int16" tensor + const SavedSliceMeta& ssm = sts.meta().tensor(3); + EXPECT_EQ("int16", ssm.name()); + EXPECT_EQ( + "dim { size: 5 } " + "dim { size: 10 }", + ssm.shape().ShortDebugString()); + EXPECT_EQ(DT_INT16, ssm.type()); + EXPECT_EQ(1, ssm.slice_size()); + TensorSlice s0(ssm.slice(0)); + EXPECT_EQ("-:3,1", s0.DebugString()); + } + } + + // We expect 5 blocks of tensor data + { + // Block 1: we expect it to be the full slice of the "AA" tensor + SavedSlice ss; + GetData(table.get(), "AA", TensorSlice(2), &ss); + const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3}; + EXPECT_EQ(ArraySize(data), ss.data().float_val_size()); + ExpectIdenticalFloatArrays(data, ArraySize(data), + ss.data().float_val().data()); + } + + { + // Block 2: we expect it to be the first slice of the "test" tensor + SavedSlice ss; + GetData(table.get(), "test", TensorSlice({{0, -1}, {0, 1}}), &ss); + const int32 data[] = {0, 1, 2, 3, 4}; + EXPECT_EQ(ArraySize(data), ss.data().int_val_size()); + ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data()); + } + + { + // Block 3: we expect it to be the second slice of the "test" tensor + SavedSlice ss; + GetData(table.get(), "test", TensorSlice({{0, -1}, {3, 1}}), &ss); + const int32 data[] = {10, 11, 12, 13, 14}; + EXPECT_EQ(ArraySize(data), ss.data().int_val_size()); + ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data()); + } + + { + // Block 4: we expect it to be the slice of the "int64" tensor + SavedSlice ss; + GetData(table.get(), "int64", TensorSlice({{0, -1}, {3, 1}}), &ss); + const int64 data[] = {10, 11, 12, 13, 14}; + EXPECT_EQ(ArraySize(data), ss.data().int64_val_size()); + ExpectIdenticalIntArrays(data, ArraySize(data), + ss.data().int64_val().data()); + } + + { + // Block 5: we expect it to be the slice of the "int16" tensor + SavedSlice ss; + GetData(table.get(), "int16", TensorSlice({{0, -1}, {3, 1}}), &ss); + const int16 data[] = {10, 11, 12, 13, 14}; + EXPECT_EQ(ArraySize(data), ss.data().int_val_size()); + ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data()); + } +} + +} // namespace checkpoint + +} // namespace tensorflow diff --git a/tensorflow/core/util/use_cudnn.cc b/tensorflow/core/util/use_cudnn.cc new file mode 100644 index 0000000000..544b48a679 --- /dev/null +++ b/tensorflow/core/util/use_cudnn.cc @@ -0,0 +1,20 @@ +#include "tensorflow/core/util/use_cudnn.h" + 
+#include <stdlib.h> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +bool CanUseCudnn() { + const char* tf_use_cudnn = getenv("TF_USE_CUDNN"); + if (tf_use_cudnn != nullptr) { + string tf_use_cudnn_str = tf_use_cudnn; + if (tf_use_cudnn_str == "0") { + return false; + } + } + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h new file mode 100644 index 0000000000..20ce24c513 --- /dev/null +++ b/tensorflow/core/util/use_cudnn.h @@ -0,0 +1,12 @@ +// The utility to check whether we have Cudnn depenedency. + +#ifndef TENSORFLOW_UTIL_USE_CUDNN_H_ +#define TENSORFLOW_UTIL_USE_CUDNN_H_ + +namespace tensorflow { + +bool CanUseCudnn(); + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_USE_CUDNN_H_ diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc new file mode 100644 index 0000000000..14ac513074 --- /dev/null +++ b/tensorflow/core/util/util.cc @@ -0,0 +1,81 @@ +#include "tensorflow/core/util/util.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +StringPiece NodeNamePrefix(const StringPiece& op_name) { + StringPiece sp(op_name); + auto p = sp.find('/'); + if (p == StringPiece::npos || p == 0) { + return ""; + } else { + return StringPiece(sp.data(), p); + } +} + +StringPiece NodeNameFullPrefix(const StringPiece& op_name) { + StringPiece sp(op_name); + auto p = sp.rfind('/'); + if (p == StringPiece::npos || p == 0) { + return ""; + } else { + return StringPiece(sp.data(), p); + } +} + +MovingAverage::MovingAverage(int window) + : window_(window), + sum_(0.0), + data_(new double[window_]), + head_(0), + count_(0) { + CHECK_GE(window, 1); +} + +MovingAverage::~MovingAverage() { delete[] data_; } + +void MovingAverage::Clear() { + count_ = 0; + head_ = 0; + sum_ = 0; +} + +double MovingAverage::GetAverage() const { + if (count_ == 0) { + return 0; + } else { + return static_cast<double>(sum_) / count_; + } +} + +void MovingAverage::AddValue(double v) { + if (count_ < window_) { + // This is the warmup phase. We don't have a full window's worth of data. + head_ = count_; + data_[count_++] = v; + } else { + if (window_ == ++head_) { + head_ = 0; + } + // Toss the oldest element + sum_ -= data_[head_]; + // Add the newest element + data_[head_] = v; + } + sum_ += v; +} + +static char hex_char[] = "0123456789abcdef"; + +string PrintMemory(const char* ptr, int n) { + string ret; + ret.resize(n * 3); + for (int i = 0; i < n; ++i) { + ret[i * 3] = ' '; + ret[i * 3 + 1] = hex_char[ptr[i] >> 4]; + ret[i * 3 + 2] = hex_char[ptr[i] & 0xf]; + } + return ret; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h new file mode 100644 index 0000000000..52650bd8ea --- /dev/null +++ b/tensorflow/core/util/util.h @@ -0,0 +1,40 @@ +#ifndef TENSORFLOW_UTIL_UTIL_H_ +#define TENSORFLOW_UTIL_UTIL_H_ + +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +// If op_name has '/' in it, then return everything before the first '/'. +// Otherwise return empty string. +StringPiece NodeNamePrefix(const StringPiece& op_name); + +// If op_name has '/' in it, then return everything before the last '/'. +// Otherwise return empty string. 
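// Illustrative values (editorial examples derived from the definitions in
// util.cc, not part of the original comments):
//   NodeNamePrefix("scope/sub/op")     returns "scope"
//   NodeNameFullPrefix("scope/sub/op") returns "scope/sub"
//   NodeNamePrefix("op")               returns ""   (no '/' in the name)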
+StringPiece NodeNameFullPrefix(const StringPiece& op_name); + +class MovingAverage { + public: + explicit MovingAverage(int window); + ~MovingAverage(); + + void Clear(); + + double GetAverage() const; + void AddValue(double v); + + private: + const int window_; // Max size of interval + double sum_; // Sum over interval + double* data_; // Actual data values + int head_; // Offset of the newest statistic in data_ + int count_; // # of valid data elements in window +}; + +// Returns a string printing bytes in ptr[0..n). The output looks +// like "00 01 ef cd cd ef". +string PrintMemory(const char* ptr, int n); + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_UTIL_H_ diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc new file mode 100644 index 0000000000..d9ab0805c5 --- /dev/null +++ b/tensorflow/core/util/work_sharder.cc @@ -0,0 +1,57 @@ +#include "tensorflow/core/util/work_sharder.h" + +#include <vector> +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +void Shard(int num_workers, thread::ThreadPool* workers, int64 total, + int64 cost_per_unit, std::function<void(int64, int64)> work) { + CHECK_GE(total, 0); + if (total == 0) { + return; + } + if (num_workers <= 1) { + // Just inline the whole work since we only have 1 thread (core). + work(0, total); + return; + } + cost_per_unit = std::max(1LL, cost_per_unit); + // We shard [0, total) into "num_shards" shards. + // 1 <= num_shards <= num worker threads + // + // If total * cost_per_unit is small, it is not worth shard too + // much. Let us assume each cost unit is 1ns, kMinCostPerShard=10000 + // is 10us. + static const int64 kMinCostPerShard = 10000; + const int num_shards = std::max( + 1, std::min<int>(num_workers, total * cost_per_unit / kMinCostPerShard)); + // Each shard contains up to "block_size" units. [0, total) is sharded + // into: + // [0, block_size), [block_size, 2*block_size), ... + // The 1st shard is done by the caller thread and the other shards + // are dispatched to the worker threads. The last shard may be smaller than + // block_size. + const int64 block_size = (total + num_shards - 1) / num_shards; + CHECK_GT(block_size, 0); // total > 0 guarantees this. + if (block_size >= total) { + work(0, total); + return; + } + const int num_shards_used = (total + block_size - 1) / block_size; + BlockingCounter counter(num_shards_used - 1); + for (int64 start = block_size; start < total; start += block_size) { + auto limit = std::min(start + block_size, total); + workers->Schedule([&work, &counter, start, limit]() { + work(start, limit); // Compute the shard. + counter.DecrementCount(); // The shard is done. + }); + } + + // Inline execute the 1st shard. + work(0, std::min(block_size, total)); + counter.Wait(); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h new file mode 100644 index 0000000000..1ea2cf4397 --- /dev/null +++ b/tensorflow/core/util/work_sharder.h @@ -0,0 +1,33 @@ +#ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_ +#define TENSORFLOW_UTIL_WORK_SHARDER_H_ + +#include <functional> + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +// Shards the "total" unit of work assuming each unit of work having +// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ..., +// total - 1. 
Each shard contains 1 or more units of work and the +// total cost of each shard is roughly the same. The total number of +// shards is no more than num_workers. The calling thread and the +// "workers" are used to compute each shard (calling work(start, +// limit). A common configuration is that "workers" is a thread pool +// with "num_workers" threads. +// +// "work" should be a callable taking (int64, int64) arguments. +// work(start, limit) computes the work units from [start, +// limit), i.e., [start, limit) is a shard. +// +// REQUIRES: num_workers >= 0 +// REQUIRES: workers != nullptr +// REQUIRES: total >= 0 +// REQUIRES: cost_per_unit >= 0 +void Shard(int num_workers, thread::ThreadPool* workers, int64 total, + int64 cost_per_unit, std::function<void(int64, int64)> work); + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_WORK_SHARDER_H_ diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc new file mode 100644 index 0000000000..d9792c0e8d --- /dev/null +++ b/tensorflow/core/util/work_sharder_test.cc @@ -0,0 +1,57 @@ +#include "tensorflow/core/util/work_sharder.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit) { + thread::ThreadPool threads(Env::Default(), "test", 16); + mutex mu; + int64 num_shards = 0; + int64 num_done_work = 0; + std::vector<bool> work(total, false); + Shard(num_workers, &threads, total, cost_per_unit, + [&mu, &num_shards, &num_done_work, &work](int start, int limit) { + VLOG(1) << "Shard [" << start << "," << limit << ")"; + mutex_lock l(mu); + ++num_shards; + for (; start < limit; ++start) { + EXPECT_FALSE(work[start]); // No duplicate + ++num_done_work; + work[start] = true; + } + }); + EXPECT_LE(num_shards, num_workers + 1); + EXPECT_EQ(num_done_work, total); + LOG(INFO) << num_workers << " " << total << " " << cost_per_unit << " " + << num_shards; +} + +TEST(Shard, Basic) { + for (auto workers : {0, 1, 2, 3, 5, 7, 10, 11, 15, 100, 1000}) { + for (auto total : {0, 1, 7, 10, 64, 100, 256, 1000, 9999}) { + for (auto cost_per_unit : {0, 1, 11, 102, 1003, 10005, 1000007}) { + RunSharding(workers, total, cost_per_unit); + } + } + } +} + +void BM_Sharding(int iters, int arg) { + thread::ThreadPool threads(Env::Default(), "test", 16); + const int64 total = 1LL << 30; + auto lambda = [](int64 start, int64 limit) {}; + auto work = std::cref(lambda); + for (; iters > 0; iters -= arg) { + Shard(arg - 1, &threads, total, 1, work); + } +} +BENCHMARK(BM_Sharding)->Range(1, 128); + +} // namespace +} // namespace tensorflow |
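
As a hedged illustration of the Shard() interface documented above (an editorial sketch, not part of the diff): a caller typically owns a thread pool and passes a closure over the half-open [start, limit) range each shard should process. The helper name SquareInParallel, the pool name and size, and the cost-per-unit estimate below are assumptions made for the example; the includes mirror those used by work_sharder_test.cc.

  #include <vector>

  #include "tensorflow/core/lib/core/threadpool.h"
  #include "tensorflow/core/platform/port.h"
  #include "tensorflow/core/util/work_sharder.h"

  namespace tensorflow {

  // Squares every element of "values" in parallel, assuming each element
  // costs roughly 100ns to process.  Hypothetical helper for illustration.
  void SquareInParallel(std::vector<double>* values) {
    const int num_threads = 8;  // assumed pool size
    thread::ThreadPool pool(Env::Default(), "square", num_threads);
    Shard(num_threads, &pool, static_cast<int64>(values->size()),
          /*cost_per_unit=*/100, [values](int64 start, int64 limit) {
            for (int64 i = start; i < limit; ++i) {
              (*values)[i] *= (*values)[i];
            }
          });
    // Shard() returns only after every shard, including the one executed
    // inline on the calling thread, has finished.
  }

  }  // namespace tensorflow

The closure must tolerate being invoked concurrently from multiple pool threads, so it should only touch the indices it is handed; the sharder guarantees the ranges are disjoint and cover [0, total).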