-rw-r--r--  eigen.BUILD | 2
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/BUILD | 17
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h | 46
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h | 75
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/loss.h | 53
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/loss_updaters_test.cc | 58
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/resources.cc | 14
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/resources.h | 26
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/resources_test.cc | 4
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc | 389
-rw-r--r--  tensorflow/contrib/linear_optimizer/kernels/squared-loss.h | 31
-rw-r--r--  tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc | 56
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 259
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py | 89
-rw-r--r--  tensorflow/core/BUILD | 114
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device_factory.cc | 44
-rw-r--r--  tensorflow/core/common_runtime/threadpool_device.cc | 10
-rw-r--r--  tensorflow/core/distributed_runtime/BUILD | 12
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/BUILD | 7
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 122
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h | 65
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc | 4
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc | 5
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc | 3
-rw-r--r--  tensorflow/core/distributed_runtime/server_lib.cc | 73
-rw-r--r--  tensorflow/core/distributed_runtime/server_lib.h | 98
-rw-r--r--  tensorflow/core/framework/op_kernel.h | 8
-rw-r--r--  tensorflow/core/kernels/BUILD | 45
-rw-r--r--  tensorflow/core/kernels/bounds_check.h | 13
-rw-r--r--  tensorflow/core/kernels/decode_csv_op.cc | 7
-rw-r--r--  tensorflow/core/kernels/in_topk_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.cc | 37
-rw-r--r--  tensorflow/core/kernels/stack_ops.cc | 75
-rw-r--r--  tensorflow/core/kernels/transpose_op.cc | 2
-rw-r--r--  tensorflow/core/lib/random/philox_random.h | 2
-rw-r--r--  tensorflow/core/platform/macros.h | 4
-rw-r--r--  tensorflow/core/protobuf/tensorflow_server.proto | 5
-rw-r--r--  tensorflow/examples/udacity/README.md | 7
-rw-r--r--  tensorflow/models/image/mnist/convolutional.py | 2
-rw-r--r--  tensorflow/python/BUILD | 26
-rw-r--r--  tensorflow/python/__init__.py | 4
-rw-r--r--  tensorflow/python/client/client_lib.py | 6
-rw-r--r--  tensorflow/python/client/server_lib.i | 88
-rw-r--r--  tensorflow/python/client/server_lib.py | 86
-rw-r--r--  tensorflow/python/client/server_lib_test.py | 65
-rw-r--r--  tensorflow/python/framework/gen_docs_combined.py | 4
-rw-r--r--  tensorflow/python/framework/ops.py | 19
-rw-r--r--  tensorflow/python/framework/ops_test.py | 8
-rw-r--r--  tensorflow/python/kernel_tests/constant_op_test.py | 13
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/decode_csv_op_test.py | 15
-rw-r--r--  tensorflow/python/kernel_tests/in_topk_op_test.py | 8
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py | 19
-rw-r--r--  tensorflow/python/ops/array_ops.py | 10
-rw-r--r--  tensorflow/python/ops/control_flow_grad.py | 82
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 429
-rw-r--r--  tensorflow/python/ops/gradients.py | 18
-rw-r--r--  tensorflow/python/ops/image_ops.py | 3
-rw-r--r--  tensorflow/python/ops/rnn.py | 58
-rw-r--r--  tensorflow/python/tensorflow.i | 1
-rw-r--r--  tensorflow/python/training/coordinator.py | 7
-rw-r--r--  tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts | 2
-rw-r--r--  tensorflow/tensorboard/components/tf-categorizer/test/index.html | 13
-rw-r--r--  tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts | 19
-rw-r--r--  tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts | 3
-rw-r--r--  tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts | 3
-rw-r--r--  tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts | 3
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts | 2
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts | 4
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts | 4
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts | 2
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/render.ts | 5
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts | 7
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts | 5
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts | 3
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts | 6
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts | 6
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/template.ts | 4
-rw-r--r--  tensorflow/tensorboard/components/tf-tensorboard/demo/index.html | 43
-rw-r--r--  tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html | 19
-rw-r--r--  tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html | 70
-rw-r--r--  tensorflow/tensorboard/components/tf-test/index.html | 16
-rw-r--r--  tensorflow/tensorboard/gulpfile.js | 3
-rw-r--r--  tensorflow/tensorboard/lib/js/backend/test/index.html | 6
-rw-r--r--  tensorflow/tensorboard/lib/js/nanite/test/index.html | 9
-rw-r--r--  tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts | 6
-rw-r--r--  tensorflow/tensorboard/lib/js/node-radar/test/index.html | 6
-rw-r--r--  tensorflow/tensorboard/lib/js/requestManager/test/index.html | 3
-rw-r--r--  tensorflow/tensorboard/package.json | 2
-rw-r--r--  tensorflow/tensorboard/wct.conf.json | 12
-rw-r--r--  tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb | 4
-rw-r--r--  tensorflow/workspace.bzl | 4
-rw-r--r--  third_party/eigen3/Eigen/Cholesky | 2
-rw-r--r--  third_party/eigen3/Eigen/Core | 2
-rw-r--r--  third_party/eigen3/Eigen/Eigenvalues | 2
-rw-r--r--  third_party/eigen3/Eigen/LU | 2
-rw-r--r--  third_party/eigen3/Eigen/QR | 2
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/Tensor | 2
-rw-r--r--  tools/bazel.rc.template | 2
99 files changed, 2009 insertions, 1159 deletions
diff --git a/eigen.BUILD b/eigen.BUILD
index 958772ee9d..806b6d36b9 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"])
-archive_dir = "eigen-eigen-f1ce2528ee99"
+archive_dir = "eigen-eigen-88444e025a5c"
cc_library(
name = "eigen",
diff --git a/tensorflow/contrib/linear_optimizer/kernels/BUILD b/tensorflow/contrib/linear_optimizer/kernels/BUILD
index 682fd6f822..2e56171211 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/BUILD
+++ b/tensorflow/contrib/linear_optimizer/kernels/BUILD
@@ -7,16 +7,27 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
cc_library(
- name = "losses",
+ name = "loss_updaters",
hdrs = [
"hinge-loss.h",
"logistic-loss.h",
+ "loss.h",
"squared-loss.h",
],
deps = ["//tensorflow/core:lib"],
)
-# TODO(katsiapis): Add tests for losses.
+cc_test(
+ name = "loss_updaters_test",
+ srcs = ["loss_updaters_test.cc"],
+ deps = [
+ ":loss_updaters",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
cc_library(
name = "resources",
@@ -44,7 +55,7 @@ cc_library(
name = "sdca_ops",
srcs = ["sdca_ops.cc"],
deps = [
- ":losses",
+ ":loss_updaters",
":resources",
"//third_party/eigen3",
"//tensorflow/core:framework",
diff --git a/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h b/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h
index 877fceeeb4..3655fa707e 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h
+++ b/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h
@@ -19,8 +19,13 @@ limitations under the License.
#include <algorithm>
#include <cmath>
+#include "tensorflow/contrib/linear_optimizer/kernels/loss.h"
+#include "tensorflow/core/lib/core/errors.h"
+
namespace tensorflow {
-struct hinge_loss {
+
+class HingeLossUpdater : public DualLossUpdater {
+ public:
// Computes the updated dual variable corresponding to a single example. The
// updated dual value maximizes the objective function of the dual
// optimization problem associated with hinge loss (conditioned on keeping the
@@ -30,13 +35,11 @@ struct hinge_loss {
// and the particular form of conjugate function for hinge loss.
// TODO(pmol): Write up a doc with concrete derivation and point to it from
// here.
- inline static double ComputeUpdatedDual(const double label,
- const double example_weight,
- const double current_dual,
- const double wx,
- const double weighted_example_norm,
- const double primal_loss,
- const double dual_loss) {
+ double ComputeUpdatedDual(const double label, const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm,
+ const double primal_loss,
+ const double dual_loss) const final {
// Intuitively there are 3 cases:
// a. The new optimal value of the dual variable falls within the admissible
// range [0, 1]. In this case we set the new dual to this value.
@@ -65,9 +68,8 @@ struct hinge_loss {
// on its label. In particular:
// \phi_y*(z) = y*z if y*z \in [-w, 0] and +infinity everywhere else where
// y \in {-1,1}. The following method implements \phi_y*(-\alpha/w).
- inline static double ComputeDualLoss(const double current_dual,
- const double example_label,
- const double example_weight) {
+ double ComputeDualLoss(const double current_dual, const double example_label,
+ const double example_weight) const final {
// For binary classification, there are 2 conjugate functions, one per
// label value (-1 and 1).
const double y_alpha = current_dual * example_label; // y \alpha
@@ -80,13 +82,29 @@ struct hinge_loss {
// Hinge loss for binary classification for a single example. Hinge loss
// equals max(0, 1 - y * wx) (see https://en.wikipedia.org/wiki/Hinge_loss).
// For weighted instances loss should be multiplied by the instance weight.
- inline static double ComputePrimalLoss(const double wx,
- const double example_label,
- const double example_weight) {
+ double ComputePrimalLoss(const double wx, const double example_label,
+ const double example_weight) const final {
const double y_wx = example_label * wx;
return std::max(0.0, 1 - y_wx) * example_weight;
}
+
+ // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively
+ // as expected by hinge loss.
+ Status ConvertLabel(float* const example_label) const final {
+ if (*example_label == 0.0) {
+ *example_label = -1;
+ return Status::OK();
+ }
+ if (*example_label == 1.0) {
+ return Status::OK();
+ }
+ return errors::InvalidArgument(
+ "Only labels of 0.0 or 1.0 are supported right now. "
+ "Found example with label: ",
+ *example_label);
+ }
};
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_HINGE_LOSS_H_
diff --git a/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h b/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h
index d75a707820..b18116be9d 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h
+++ b/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h
@@ -19,44 +19,21 @@ limitations under the License.
#include <algorithm>
#include <cmath>
+#include "tensorflow/contrib/linear_optimizer/kernels/loss.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-struct logistic_loss {
- // Partial derivative of the logistic loss w.r.t (1 + exp(-ywx)).
- inline static double PartialDerivativeLogisticLoss(const double wx,
- const double label) {
- // To avoid overflow, we compute partial derivative of logistic loss as
- // follows.
- const double ywx = label * wx;
- if (ywx > 0) {
- const double exp_minus_ywx = exp(-ywx);
- return exp_minus_ywx / (1 + exp_minus_ywx);
- }
- return 1 / (1 + exp(ywx));
- }
-
- // Smoothness constant for the logistic loss.
- inline static double SmoothnessConstantLogisticLoss(
- const double partial_derivative_loss, const double wx,
- const double label) {
- // Upper bound on the smoothness constant of log loss. This is 0.25 i.e.
- // when log-odds is zero.
- return (wx == 0) ? 0.25
- : (1 - 2 * partial_derivative_loss) / (2 * label * wx);
- }
+class LogisticLossUpdater : public DualLossUpdater {
+ public:
// Use an approximate step that is guaranteed to decrease the dual loss.
// Derivation of this is available on Page 14, Eq 16 of
// http://arxiv.org/pdf/1211.2717v1.pdf
- inline static double ComputeUpdatedDual(const double label,
- const double example_weight,
- const double current_dual,
- const double wx,
- const double weighted_example_norm,
- const double primal_loss,
- const double dual_loss) {
+ double ComputeUpdatedDual(const double label, const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm,
+ const double primal_loss,
+ const double dual_loss) const final {
const double partial_derivative_loss =
PartialDerivativeLogisticLoss(label, wx);
+ // f*(a) = sup (a*x - f(x)), attained at a = f'(x), where a is the approximate dual.
@@ -81,9 +58,8 @@ struct logistic_loss {
// Dual of the logistic loss function.
// https://en.wikipedia.org/wiki/Convex_conjugate
- inline static double ComputeDualLoss(const double current_dual,
- const double example_label,
- const double example_weight) {
+ double ComputeDualLoss(const double current_dual, const double example_label,
+ const double example_weight) const final {
// Dual of the logistic loss function is
// ay * log(ay) + (1-ay) * log (1-ay), where a is the dual variable.
const double ay = current_dual * example_label;
@@ -95,9 +71,8 @@ struct logistic_loss {
// Logistic loss for binary classification.
// https://en.wikipedia.org/wiki/Loss_functions_for_classification
- inline static double ComputePrimalLoss(const double wx,
- const double example_label,
- const double example_weight) {
+ double ComputePrimalLoss(const double wx, const double example_label,
+ const double example_weight) const final {
// Logistic loss:
// log(1 + e^(-ywx))
// log(e^0 + e^(-ywx))
@@ -117,7 +92,7 @@ struct logistic_loss {
// Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively
// as expected by logistic regression.
- inline static Status ConvertLabel(float* const example_label) {
+ Status ConvertLabel(float* const example_label) const final {
if (*example_label == 0.0) {
*example_label = -1;
return Status::OK();
@@ -130,6 +105,30 @@ struct logistic_loss {
"Found example with label: ",
*example_label);
}
+
+ private:
+ // Partial derivative of the logistic loss w.r.t (1 + exp(-ywx)).
+ static inline double PartialDerivativeLogisticLoss(const double wx,
+ const double label) {
+ // To avoid overflow, we compute partial derivative of logistic loss as
+ // follows.
+ const double ywx = label * wx;
+ if (ywx > 0) {
+ const double exp_minus_ywx = exp(-ywx);
+ return exp_minus_ywx / (1 + exp_minus_ywx);
+ }
+ return 1 / (1 + exp(ywx));
+ }
+
+ // Smoothness constant for the logistic loss.
+ static inline double SmoothnessConstantLogisticLoss(
+ const double partial_derivative_loss, const double wx,
+ const double label) {
+ // Upper bound on the smoothness constant of log loss. This is 0.25, i.e.
+ // the value attained when the log-odds is zero.
+ return (wx == 0) ? 0.25
+ : (1 - 2 * partial_derivative_loss) / (2 * label * wx);
+ }
};
} // namespace tensorflow
diff --git a/tensorflow/contrib/linear_optimizer/kernels/loss.h b/tensorflow/contrib/linear_optimizer/kernels/loss.h
new file mode 100644
index 0000000000..d827d6f764
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/kernels/loss.h
@@ -0,0 +1,53 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOSS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOSS_H_
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class DualLossUpdater {
+ public:
+ virtual ~DualLossUpdater() {}
+
+ // Computes the updated dual (alpha) based on a single example. Various
+ // strategies can be employed here, such as a Newton step, a line search, or
+ // an approximate step that decreases the dual sub-optimality.
+ virtual double ComputeUpdatedDual(const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm,
+ const double primal_loss,
+ const double dual_loss) const = 0;
+
+ // Computes the dual loss based on the current dual (alpha), the example
+ // label (y) and the example weight (cost).
+ virtual double ComputeDualLoss(const double current_dual,
+ const double example_label,
+ const double example_weight) const = 0;
+
+ // Computes the primal loss based on the current estimate of the log-odds (wx),
+ // example label (y) and example weight (cost).
+ virtual double ComputePrimalLoss(const double wx, const double example_label,
+ const double example_weight) const = 0;
+
+ // Converts binary example labels from 0.0 or 1.0 to the appropriate range for
+ // each loss function.
+ virtual Status ConvertLabel(float* const example_label) const = 0;
+};
+
+} // namespace tensorflow
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOSS_H_
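As a minimal sketch of how the DualLossUpdater interface above is meant to be consumed (the free function below is hypothetical and purely illustrative; the real call sites are in kernels/sdca_ops.cc later in this diff), polymorphic usage looks roughly like this:

#include "tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h"
#include "tensorflow/contrib/linear_optimizer/kernels/loss.h"

namespace tensorflow {

// Hypothetical helper: converts a 0/1 label to the range the loss expects,
// then evaluates the primal loss through the abstract interface. SdcaSolver
// does the same thing via loss_updater_->ComputePrimalLoss(...).
inline double PrimalLossFor01Label(const DualLossUpdater& updater, float label,
                                   const double wx,
                                   const double example_weight) {
  TF_CHECK_OK(updater.ConvertLabel(&label));
  return updater.ComputePrimalLoss(wx, label, example_weight);
}

}  // namespace tensorflow

// Example use: with wx == 0 and unit weight, logistic loss gives log(2) ~= 0.693.
//   std::unique_ptr<tensorflow::DualLossUpdater> updater(
//       new tensorflow::LogisticLossUpdater);
//   const double loss =
//       tensorflow::PrimalLossFor01Label(*updater, 1.0f /* label */, 0.0, 1.0);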
diff --git a/tensorflow/contrib/linear_optimizer/kernels/loss_updaters_test.cc b/tensorflow/contrib/linear_optimizer/kernels/loss_updaters_test.cc
new file mode 100644
index 0000000000..7d9f05609b
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/kernels/loss_updaters_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h"
+#include "tensorflow/contrib/linear_optimizer/kernels/squared-loss.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(LogisticLoss, ComputePrimalLoss) {
+ LogisticLossUpdater loss_updater;
+ EXPECT_NEAR(0.693147, loss_updater.ComputePrimalLoss(
+ 0 /* wx */, 1 /* label */, 1 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(0.0, loss_updater.ComputePrimalLoss(70 /* wx */, 1 /* label */,
+ 1 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(0.0, loss_updater.ComputePrimalLoss(-70 /* wx */, -1 /* label */,
+ 1 /* example weight */),
+ 1e-3);
+}
+
+TEST(LogisticLoss, ComputeDualLoss) {
+ LogisticLossUpdater loss_updater;
+ EXPECT_NEAR(0.0,
+ loss_updater.ComputeDualLoss(0 /* current dual */, 1 /* label */,
+ 1 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(0.0,
+ loss_updater.ComputeDualLoss(1 /* current dual */, 1 /* label */,
+ 1 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(-0.693147, loss_updater.ComputeDualLoss(0.5 /* current dual */,
+ 1 /* label */,
+ 1 /* example weight */),
+ 1e-3);
+}
+
+// TODO(rohananil): Add tests for dual update.
+// TODO(dbaylor): Add tests for squared loss.
+// TODO(pmol): Add tests for hinge loss.
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources.cc b/tensorflow/contrib/linear_optimizer/kernels/resources.cc
index 392ceac12a..d6266616f1 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/resources.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/resources.cc
@@ -44,6 +44,16 @@ DataByExample::Key DataByExample::MakeKey(const string& example_id) {
Hash64(example_id.data(), example_id.size(), kSeed2) & 0xFFFFFFFF);
}
+DataByExample::Data DataByExample::Get(const Key& key) {
+ mutex_lock l(mu_);
+ return data_by_key_[key];
+}
+
+void DataByExample::Set(const Key& key, const Data& data) {
+ mutex_lock l(mu_);
+ data_by_key_[key] = data;
+}
+
Status DataByExample::Visit(
std::function<void(const Data& data)> visitor) const {
struct State {
@@ -71,8 +81,8 @@ Status DataByExample::Visit(
// be successful if and only if the size of the backing store hasn't
// changed (since the body of this while-loop is under lock).
if (data_by_key_.size() != state.size) {
- return errors::Aborted("The number of elements for ", solver_uuid_,
- " has changed which nullifies a visit.");
+ return errors::Unavailable("The number of elements for ", solver_uuid_,
+ " has changed which nullifies a visit.");
}
for (size_t i = 0; i < kVisitChunkSize && state.num_visited < state.size;
++i, ++state.num_visited, ++state.it) {
diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources.h b/tensorflow/contrib/linear_optimizer/kernels/resources.h
index cb0ea8433e..4578e3442f 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/resources.h
+++ b/tensorflow/contrib/linear_optimizer/kernels/resources.h
@@ -47,8 +47,10 @@ class DataByExample : public ResourceBase {
static Key MakeKey(const string& example_id);
struct Data {
- // TODO(rohananil): Add extra data needed for duality gap computation here.
float dual = 0;
+ float primal_loss = 0;
+ float dual_loss = 0;
+ float example_weight = 0;
// Comparison operators for ease of testing.
bool operator==(const Data& other) const { return dual == other.dual; }
@@ -58,22 +60,17 @@ class DataByExample : public ResourceBase {
// Accessor and mutator for the entry at Key. Accessor creates an entry with
// default value (default constructed object) if the key is not present and
// returns it.
- inline Data Get(const Key& key) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return data_by_key_[key];
- }
- inline void Set(const Key& key, const Data& data) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- data_by_key_[key] = data;
- }
+ Data Get(const Key& key) LOCKS_EXCLUDED(mu_);
+ void Set(const Key& key, const Data& data) LOCKS_EXCLUDED(mu_);
// Visits all elements in this resource. The view of each element (Data) is
// atomic, but the entirety of the visit is not (ie the visitor might see
// different versions of the Data across elements).
//
- // Returns OK on success or ABORTED if the number of elements in this
+ // Returns OK on success or UNAVAILABLE if the number of elements in this
// container has changed since the beginning of the visit (in which case the
- // visit cannot be completed and is aborted early).
+ // visit cannot be completed and is aborted early, and computation can be
+ // restarted).
Status Visit(std::function<void(const Data& data)> visitor) const
LOCKS_EXCLUDED(mu_);
@@ -86,8 +83,11 @@ class DataByExample : public ResourceBase {
// Backing container.
//
- // sizeof(EntryPayload) = sizeof(Key) + sizeof(Data) = 16.
- // So on average we use ~35 bytes per entry in this table.
+ // sizeof(EntryPayload) =
+ // sizeof(Key) + sizeof(Data) =
+ // 12 + 16 = 28.
+ //
+ // So on average we use ~47.5 (28 + 19.5) bytes per entry in this table.
using DataByKey = std::unordered_map<Key, Data, KeyHash>;
// TODO(katsiapis): Benchmark and/or optimize this.
diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc
index 9a94c54bc5..1981db3160 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc
@@ -103,7 +103,7 @@ TEST_F(DataByExampleTest, VisitMany) {
(kNumElements - 1) * kNumElements / 2.0, total_dual);
}
-TEST_F(DataByExampleTest, VisitAborted) {
+TEST_F(DataByExampleTest, VisitUnavailable) {
// Populate enough entries so that Visiting will be chunked.
for (size_t i = 0; i < 2 * VisitChunkSize(); ++i) {
data_by_example_->Get(DataByExample::MakeKey(strings::StrCat(i)));
@@ -151,7 +151,7 @@ TEST_F(DataByExampleTest, VisitAborted) {
});
wait(&completed_visit);
EXPECT_FALSE(thread_pool.HasPendingClosures());
- EXPECT_TRUE(errors::IsAborted(status));
+ EXPECT_TRUE(errors::IsUnavailable(status));
}
} // namespace tensorflow
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
index 209c9a85e8..76671a47cd 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
@@ -78,7 +78,6 @@ struct RegularizationLoss {
};
struct PerExampleData {
- double old_wx = 0;
double wx = 0;
double norm = 0;
};
@@ -256,7 +255,6 @@ inline PerExampleData ComputeWxAndWeightedExampleNorm(
const int64 index = indices(dim);
const double weight = weights(index);
const double value = values(dim);
- result.old_wx += Shrink(weight, shrink_by) * value;
result.wx += Shrink(weight + delta_weights(index), shrink_by) * value;
}
result.norm += sparse_indices_values[example_id]->norm;
@@ -265,7 +263,6 @@ inline PerExampleData ComputeWxAndWeightedExampleNorm(
for (size_t i = 0; i < dense_features_by_group.size(); ++i) {
const double weight = dense_weights_by_group[i](0);
const double value = dense_features_by_group[i](example_id);
- result.old_wx += Shrink(weight, shrink_by) * value;
result.wx +=
Shrink(weight + dense_delta_weights_by_group[i](0), shrink_by) * value;
result.norm += value * value;
@@ -300,6 +297,7 @@ void AddDeltaWeights(const WeightsByGroup& src, WeightsByGroup* const dst) {
void ShrinkWeights(const Regularizations& regularizations,
WeightsByGroup* const sparse_weights_by_group,
WeightsByGroup* const dense_weights_by_group) {
+ // TODO(rohananil): Parallelize shrinking.
const double shrink_by = ShrinkageFactor(regularizations);
for (TTypes<float>::Vec weights : *sparse_weights_by_group) {
for (int64 i = 0; i < weights.size(); ++i) {
@@ -380,6 +378,88 @@ WeightsByGroup MakeDeltaWeightsFrom(std::vector<Tensor>* const tensors) {
return result;
}
+Status RunTrainStepsForMiniBatch(
+ const std::vector<int64>& example_indices,
+ const TTypes<const string>::Vec example_ids,
+ const TTypes<const float>::Vec example_labels,
+ const TTypes<const float>::Vec example_weights,
+ const DeviceBase::CpuWorkerThreads& worker_threads,
+ const Regularizations& regularizations,
+ const WeightsByGroup& sparse_weights_by_group,
+ const SparseExamplesByGroup& sparse_examples_by_group,
+ const WeightsByGroup& dense_weights_by_group,
+ const DenseFeaturesByGroup& dense_features_by_group,
+ const DualLossUpdater& loss_updater,
+ WeightsByGroup* const sparse_delta_weights_by_group,
+ WeightsByGroup* const dense_delta_weights_by_group,
+ DataByExample* const data_by_example) {
+ // Process examples in parallel, in a partitioned fashion.
+ mutex mu;
+ Status train_step_status GUARDED_BY(mu);
+ auto train_step = [&](const int64 begin, const int64 end) {
+ for (int64 offset = begin; offset < end; ++offset) {
+ // Get example id, label, and weight.
+ const int64 example_index = example_indices[offset];
+ const DataByExample::Key example_key =
+ DataByExample::MakeKey(example_ids(example_index));
+ DataByExample::Data data = data_by_example->Get(example_key);
+ const double example_weight = example_weights(example_index);
+ float example_label = example_labels(example_index);
+ const Status conversion_status =
+ loss_updater.ConvertLabel(&example_label);
+ if (!conversion_status.ok()) {
+ mutex_lock l(mu);
+ train_step_status = conversion_status;
+ // Return from this worker thread - the calling thread is
+ // responsible for checking context status and returning on error.
+ return;
+ }
+
+ // Compute wx, example norm weighted by regularization, dual loss,
+ // primal loss.
+ const PerExampleData per_example_data = ComputeWxAndWeightedExampleNorm(
+ example_index, sparse_weights_by_group,
+ *sparse_delta_weights_by_group, sparse_examples_by_group,
+ dense_weights_by_group, *dense_delta_weights_by_group,
+ dense_features_by_group, regularizations);
+
+ const double primal_loss = loss_updater.ComputePrimalLoss(
+ per_example_data.wx, example_label, example_weight);
+
+ const double dual_loss = loss_updater.ComputeDualLoss(
+ data.dual, example_label, example_weight);
+
+ const double new_dual = loss_updater.ComputeUpdatedDual(
+ example_label, example_weight, data.dual, per_example_data.wx,
+ per_example_data.norm, primal_loss, dual_loss);
+
+ // Compute new weights.
+ const double bounded_dual_delta = (new_dual - data.dual) * example_weight;
+ UpdateDeltaWeights(
+ example_index, sparse_examples_by_group, dense_features_by_group,
+ bounded_dual_delta, regularizations.symmetric_l2,
+ sparse_delta_weights_by_group, dense_delta_weights_by_group);
+
+ // Update example data.
+ data.dual = new_dual;
+ data.primal_loss = primal_loss;
+ data.dual_loss = dual_loss;
+ data.example_weight = example_weight;
+ data_by_example->Set(example_key, data);
+ }
+ // TODO(rohananil): We may in the future want to make the primal-dual
+ // relationship consistent as our current updates are not
+ // transactional.
+ };
+ // TODO(rohananil): Current multiplier 100000 works well empirically
+ // but perhaps we can tune it better.
+ const int64 kCostPerUnit = 100000 * (sparse_examples_by_group.size() +
+ dense_features_by_group.size());
+ Shard(worker_threads.num_threads, worker_threads.workers,
+ example_indices.size(), kCostPerUnit, train_step);
+ return train_step_status;
+}
+
} // namespace
class SdcaSolver : public OpKernel {
@@ -388,21 +468,11 @@ class SdcaSolver : public OpKernel {
string loss_type;
OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
if (loss_type == "logistic_loss") {
- compute_dual_loss_ = logistic_loss::ComputeDualLoss;
- compute_primal_loss_ = logistic_loss::ComputePrimalLoss;
- compute_dual_update_ = logistic_loss::ComputeUpdatedDual;
- convert_label_ = logistic_loss::ConvertLabel;
+ loss_updater_.reset(new LogisticLossUpdater);
} else if (loss_type == "squared_loss") {
- compute_dual_loss_ = squared_loss::ComputeDualLoss;
- compute_primal_loss_ = squared_loss::ComputePrimalLoss;
- compute_dual_update_ = squared_loss::ComputeUpdatedDual;
- convert_label_ = squared_loss::ConvertLabel;
+ loss_updater_.reset(new SquaredLossUpdater);
} else if (loss_type == "hinge_loss") {
- compute_dual_loss_ = hinge_loss::ComputeDualLoss;
- compute_primal_loss_ = hinge_loss::ComputePrimalLoss;
- compute_dual_update_ = hinge_loss::ComputeUpdatedDual;
- // Label conversion is identical for hinge and logistic loss.
- convert_label_ = logistic_loss::ConvertLabel;
+ loss_updater_.reset(new HingeLossUpdater);
} else {
OP_REQUIRES(context, false, errors::InvalidArgument(
"Unsupported loss type: ", loss_type));
@@ -424,8 +494,16 @@ class SdcaSolver : public OpKernel {
regularizations_.symmetric_l2 =
std::max(regularizations_.symmetric_l2, 1.0f);
- OP_REQUIRES_OK(context, context->GetAttr("duality_gap_threshold",
- &duality_gap_threshold_));
+ OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
+ &num_inner_iterations_));
+
+ // TODO(rohananil): Provide empirical evidence for this. It is better to run
+ // more than one iteration on a single mini-batch, as we want to spend more
+ // time in compute. SDCA works better with larger mini-batches, and there
+ // is also recent work showing it is better to reuse old samples than to train
+ // on new samples. See: http://arxiv.org/abs/1602.02136.
+ num_inner_iterations_ =
+ std::max(num_inner_iterations_, static_cast<int64>(2));
OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
}
@@ -448,7 +526,7 @@ class SdcaSolver : public OpKernel {
}));
OP_REQUIRES(
context, !data_by_example->RefCountIsOne(),
- errors::Internal("Expected shared-ownership of duals_by_example."));
+ errors::Internal("Expected shared-ownership of data_by_example."));
const Tensor* example_weights_t;
OP_REQUIRES_OK(context,
@@ -467,14 +545,6 @@ class SdcaSolver : public OpKernel {
errors::InvalidArgument("No weighted examples in ",
num_examples, " training examples"));
- Tensor primal_loss_t;
- OP_REQUIRES_OK(context,
- context->mutable_input("primal_loss", &primal_loss_t,
- /*lock_held=*/true));
- OP_REQUIRES(context, TensorShapeUtils::IsScalar(primal_loss_t.shape()),
- errors::InvalidArgument("primal_loss should be a scalar."));
- auto primal_loss = primal_loss_t.scalar<double>();
-
OpInputList dense_features_inputs;
OP_REQUIRES_OK(
context, context->input_list("dense_features", &dense_features_inputs));
@@ -527,6 +597,7 @@ class SdcaSolver : public OpKernel {
OpMutableInputList sparse_weights_inputs;
OP_REQUIRES_OK(context, context->mutable_input_list(
"sparse_weights", &sparse_weights_inputs));
+
WeightsByGroup sparse_weights_by_group =
MakeWeightsFrom(&sparse_weights_inputs);
std::vector<Tensor> sparse_delta_weights_by_group_backing_store =
@@ -534,11 +605,10 @@ class SdcaSolver : public OpKernel {
WeightsByGroup sparse_delta_weights_by_group =
MakeDeltaWeightsFrom(&sparse_delta_weights_by_group_backing_store);
- // TODO(rohananil): Remove the code duplication between sparse and
- // dense weights.
OpMutableInputList dense_weights_inputs;
OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights",
&dense_weights_inputs));
+
WeightsByGroup dense_weights_by_group =
MakeWeightsFrom(&dense_weights_inputs);
std::vector<Tensor> dense_delta_weights_by_group_backing_store =
@@ -563,140 +633,165 @@ class SdcaSolver : public OpKernel {
*context->device()->tensorflow_cpu_worker_threads(),
&sparse_examples_by_group));
- // Those will be shuffled below at each iteration and processed in a
- // partitioned fashion across multiple threads.
- std::vector<int64> example_indices(num_examples);
- std::iota(example_indices.begin(), example_indices.end(), 0);
-
- std::random_device random_device;
- std::mt19937 random_generator(random_device());
-
- // Break when duality gap |P(w) - D(alpha)| is less than
- // duality_gap_threshold_
- double total_duality_gap = std::numeric_limits<double>::max();
- while ((total_duality_gap / weighted_examples) > duality_gap_threshold_) {
- std::atomic<double> total_primal_loss(0);
- std::atomic<double> total_dual_loss(0);
- SetZeroDeltaWeights(&sparse_delta_weights_by_group,
- &dense_delta_weights_by_group);
-
- // Compute regularization loss at the start of the iteration so that
- // we can compute an exact value of duality gap (for the weights from
- // the previous iteration).
- const RegularizationLoss regularization_loss = ComputeRegularizationLoss(
- sparse_weights_by_group, dense_weights_by_group, regularizations_);
+ SetZeroDeltaWeights(&sparse_delta_weights_by_group,
+ &dense_delta_weights_by_group);
+ // Examples are shuffled below at each iteration and processed in a
+ // partitioned fashion across multiple threads.
+ // TODO(rohananil): We may want to avoid shuffling inside the op
+ // but instead shuffle outside as part of the input reader. Re-evaluate
+ // once we have more data on how this op is used.
+ const std::vector<int64> example_indices = [num_examples]() {
+ std::vector<int64> result(num_examples);
+ std::iota(result.begin(), result.end(), 0);
+ std::random_device random_device;
+ std::mt19937 random_generator(random_device());
// Randomize the examples across iterations for faster convergence.
- std::shuffle(example_indices.begin(), example_indices.end(),
- random_generator);
-
- {
- // Process examples in parallel, in a partitioned fashion.
- mutex mu;
- Status update_status GUARDED_BY(mu);
- auto update_partition = [&](const int64 begin, const int64 end) {
- double dual_loss_on_example_subset = 0;
- double primal_loss_on_example_subset = 0;
- for (int64 offset = begin; offset < end; ++offset) {
- // Get example id, label, and weight.
- const int64 example_index = example_indices[offset];
- const DataByExample::Key example_key =
- DataByExample::MakeKey(example_ids(example_index));
- DataByExample::Data data = data_by_example->Get(example_key);
- const double example_weight = example_weights(example_index);
- float example_label = example_labels(example_index);
- const Status conversion_status = convert_label_(&example_label);
- if (!conversion_status.ok()) {
- mutex_lock l(mu);
- update_status = conversion_status;
- // Return from this worker thread - the calling thread is
- // responsible for checking context status and returning on error.
- return;
- }
-
- // Compute wx, example norm weighted by regularization, dual loss,
- // primal loss.
- const PerExampleData per_example_data =
- ComputeWxAndWeightedExampleNorm(
- example_index, sparse_weights_by_group,
- sparse_delta_weights_by_group, sparse_examples_by_group,
- dense_weights_by_group, dense_delta_weights_by_group,
- dense_features_by_group, regularizations_);
- // Compute primal based on the previous iteration.
- primal_loss_on_example_subset += compute_primal_loss_(
- per_example_data.old_wx, example_label, example_weight);
-
- const double primal_loss = compute_primal_loss_(
- per_example_data.wx, example_label, example_weight);
-
- const double dual_loss =
- compute_dual_loss_(data.dual, example_label, example_weight);
- dual_loss_on_example_subset += dual_loss;
-
- const double new_dual = compute_dual_update_(
- example_label, example_weight, data.dual, per_example_data.wx,
- per_example_data.norm, primal_loss, dual_loss);
-
- // Compute new weights.
- const double bounded_dual_delta =
- (new_dual - data.dual) * example_weight;
- UpdateDeltaWeights(example_index, sparse_examples_by_group,
- dense_features_by_group, bounded_dual_delta,
- regularizations_.symmetric_l2,
- &sparse_delta_weights_by_group,
- &dense_delta_weights_by_group);
-
- // Update dual variable.
- data.dual = new_dual;
- data_by_example->Set(example_key, data);
- }
- AtomicAdd(primal_loss_on_example_subset, &total_primal_loss);
- AtomicAdd(dual_loss_on_example_subset, &total_dual_loss);
- // TODO(rohananil): We may in the future want to make the primal-dual
- // relationship consistent as our current updates are not
- // transactional.
- };
- const DeviceBase::CpuWorkerThreads& worker_threads =
- *context->device()->tensorflow_cpu_worker_threads();
- // TODO(katsiapis): Current multiplier (100,000) works well empirically
- // but perhaps we can tune it better.
- const int64 kCostPerUnit =
- 100000 * (num_sparse_features_ + num_dense_features_);
- Shard(worker_threads.num_threads, worker_threads.workers, num_examples,
- kCostPerUnit, update_partition);
- OP_REQUIRES_OK(context, update_status);
- }
+ std::shuffle(result.begin(), result.end(), random_generator);
+ return result;
+ }();
- total_duality_gap = total_primal_loss.load() + total_dual_loss.load() +
- regularization_loss.l1_loss +
- regularization_loss.l2_loss;
- primal_loss() = (total_primal_loss.load() + regularization_loss.l1_loss +
- regularization_loss.l2_loss) /
- weighted_examples;
- AddDeltaWeights(sparse_delta_weights_by_group, &sparse_weights_by_group);
- AddDeltaWeights(dense_delta_weights_by_group, &dense_weights_by_group);
+ // TODO(rohananil): Provide empirical evidence for this. It is better to run
+ // more than one iteration on a single mini-batch, as we want to spend more
+ // time in compute. SDCA works better with larger mini-batches, and there
+ // is also recent work showing it is better to reuse old samples than to train
+ // on new samples. See: http://arxiv.org/abs/1602.02136.
+ for (int64 i = 0; i < num_inner_iterations_; ++i) {
+ OP_REQUIRES_OK(
+ context,
+ RunTrainStepsForMiniBatch(
+ example_indices, example_ids, example_labels, example_weights,
+ *context->device()->tensorflow_cpu_worker_threads(),
+ regularizations_, sparse_weights_by_group,
+ sparse_examples_by_group, dense_weights_by_group,
+ dense_features_by_group, *loss_updater_,
+ &sparse_delta_weights_by_group, &dense_delta_weights_by_group,
+ data_by_example));
}
- ShrinkWeights(regularizations_, &sparse_weights_by_group,
- &dense_weights_by_group);
+
+ // TODO(rohananil): Change to atomic<float> as we are not exposing delta
+ // weights to users. This will allow us to simplify the code that currently
+ // keeps a backing store for the tensors. This also avoids losing updates
+ // when done in a lockless way.
+ AddDeltaWeights(sparse_delta_weights_by_group, &sparse_weights_by_group);
+ AddDeltaWeights(dense_delta_weights_by_group, &dense_weights_by_group);
// TODO(katsiapis): Use core::ScopedUnref once it's moved out of internal.
data_by_example->Unref();
}
private:
- std::function<decltype(logistic_loss::ComputeDualLoss)> compute_dual_loss_;
- std::function<decltype(logistic_loss::ComputePrimalLoss)>
- compute_primal_loss_;
- std::function<decltype(logistic_loss::ComputeUpdatedDual)>
- compute_dual_update_;
- std::function<decltype(logistic_loss::ConvertLabel)> convert_label_;
+ // TODO(rohananil): We could use the type-constraint on loss_type, and
+ // template the entire class to avoid the virtual table lookup penalty in
+ // the inner loop.
+ std::unique_ptr<DualLossUpdater> loss_updater_;
int64 num_sparse_features_;
int64 num_dense_features_;
Regularizations regularizations_;
- float duality_gap_threshold_;
+ int64 num_inner_iterations_;
string container_;
string solver_uuid_;
};
REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver);
+class SdcaShrinkL1 : public OpKernel {
+ public:
+ explicit SdcaShrinkL1(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("l1", &regularizations_.symmetric_l1));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("l2", &regularizations_.symmetric_l2));
+ // We enforce a minimal l2, required by the algorithm.
+ regularizations_.symmetric_l2 =
+ std::max(regularizations_.symmetric_l2, 1.0f);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ OpMutableInputList sparse_weights_inputs;
+ OP_REQUIRES_OK(context, context->mutable_input_list(
+ "sparse_weights", &sparse_weights_inputs));
+ WeightsByGroup sparse_weights_by_group =
+ MakeWeightsFrom(&sparse_weights_inputs);
+
+ OpMutableInputList dense_weights_inputs;
+ OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights",
+ &dense_weights_inputs));
+ WeightsByGroup dense_weights_by_group =
+ MakeWeightsFrom(&dense_weights_inputs);
+
+ ShrinkWeights(regularizations_, &sparse_weights_by_group,
+ &dense_weights_by_group);
+ }
+
+ private:
+ Regularizations regularizations_;
+};
+REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
+
+class ComputeDualityGap : public OpKernel {
+ public:
+ explicit ComputeDualityGap(OpKernelConstruction* context)
+ : OpKernel(context) {
+ // TODO(rohananil): Refactor grabbing common attributes across ops related
+ // to sdca.
+ OP_REQUIRES_OK(context,
+ context->GetAttr("l1", &regularizations_.symmetric_l1));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("l2", &regularizations_.symmetric_l2));
+ // We enforce a minimal l2, required by the algorithm.
+ regularizations_.symmetric_l2 =
+ std::max(regularizations_.symmetric_l2, 1.0f);
+ OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
+ OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ DataByExample* data_by_example = nullptr;
+ OP_REQUIRES_OK(context, context->resource_manager()->Lookup<DataByExample>(
+ container_, solver_uuid_, &data_by_example));
+ OP_REQUIRES(
+ context, !data_by_example->RefCountIsOne(),
+ errors::Internal("Expected shared-ownership of data_by_example."));
+
+ OpMutableInputList sparse_weights_inputs;
+ OP_REQUIRES_OK(context, context->mutable_input_list(
+ "sparse_weights", &sparse_weights_inputs));
+ WeightsByGroup sparse_weights_by_group =
+ MakeWeightsFrom(&sparse_weights_inputs);
+
+ OpMutableInputList dense_weights_inputs;
+ OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights",
+ &dense_weights_inputs));
+ WeightsByGroup dense_weights_by_group =
+ MakeWeightsFrom(&dense_weights_inputs);
+
+ double example_weight_sum = 0;
+ double total_duality_gap = 0;
+ OP_REQUIRES_OK(context,
+ data_by_example->Visit([&](const DataByExample::Data& data) {
+ example_weight_sum += data.example_weight;
+ total_duality_gap += data.primal_loss + data.dual_loss;
+ }));
+
+ const RegularizationLoss regularization_loss = ComputeRegularizationLoss(
+ sparse_weights_by_group, dense_weights_by_group, regularizations_);
+ total_duality_gap +=
+ regularization_loss.l2_loss + regularization_loss.l1_loss;
+
+ Tensor* duality_gap_t = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("duality_gap", {}, &duality_gap_t));
+ duality_gap_t->scalar<float>()() = total_duality_gap / example_weight_sum;
+
+ // TODO(katsiapis): Use core::ScopedUnref once it's moved out of internal.
+ data_by_example->Unref();
+ }
+
+ private:
+ Regularizations regularizations_;
+ string container_;
+ string solver_uuid_;
+};
+REGISTER_KERNEL_BUILDER(Name("ComputeDualityGap").Device(DEVICE_CPU),
+ ComputeDualityGap);
} // namespace tensorflow
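To summarize what the ComputeDualityGap kernel above reports (notation introduced here for illustration only), the emitted scalar is the cached per-example primal and dual losses plus the regularization terms, normalized by the total example weight:

\[
\text{duality\_gap} \;=\; \frac{\sum_i \bigl(P_i + D_i\bigr) \;+\; L_1(w) \;+\; L_2(w)}{\sum_i c_i}
\]

where P_i, D_i, and c_i are the primal loss, dual loss, and example weight stored in DataByExample for example i, and L_1(w), L_2(w) are the regularization losses of the current weights.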
diff --git a/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h b/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h
index 94d3c6f8c4..fc37a98a4f 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h
+++ b/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h
@@ -19,19 +19,19 @@ limitations under the License.
#include <algorithm>
#include <cmath>
-#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/contrib/linear_optimizer/kernels/loss.h"
namespace tensorflow {
-struct squared_loss {
+
+class SquaredLossUpdater : public DualLossUpdater {
+ public:
// Closed form solution that decreases the dual squared loss.
// See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf
- inline static double ComputeUpdatedDual(const double label,
- const double example_weight,
- const double current_dual,
- const double wx,
- const double weighted_example_norm,
- const double primal_loss_unused,
- const double dual_loss_unused) {
+ double ComputeUpdatedDual(const double label, const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm,
+ const double primal_loss_unused,
+ const double dual_loss_unused) const final {
const double delta_numerator = (label - current_dual - wx) * example_weight;
const double delta_denominator =
1 + weighted_example_norm * example_weight * example_weight * 0.5;
@@ -40,27 +40,26 @@ struct squared_loss {
// Dual of squared loss function.
// https://en.wikipedia.org/wiki/Convex_conjugate
- inline static double ComputeDualLoss(const double current_dual,
- const double example_label,
- const double example_weight) {
+ double ComputeDualLoss(const double current_dual, const double example_label,
+ const double example_weight) const final {
// Dual of the squared loss function = b * (y + b/2), where b is the
// dual variable and y is the label. This is Dual(-b).
return current_dual * (0.5 * current_dual - example_label) * example_weight;
}
// Squared loss for linear regression.
- inline static double ComputePrimalLoss(const double wx,
- const double example_label,
- const double example_weight) {
+ double ComputePrimalLoss(const double wx, const double example_label,
+ const double example_weight) const final {
const double error = wx - example_label;
return error * error * example_weight * 0.5;
}
// Labels don't require conversion for linear regression.
- inline static Status ConvertLabel(float* const example_label) {
+ Status ConvertLabel(float* const example_label) const final {
return Status::OK();
}
};
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_SQUARED_LOSS_H_
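As a quick reference for the closed-form step in SquaredLossUpdater::ComputeUpdatedDual above (symbols chosen here for illustration: y is the label, alpha the current dual, w.x the prediction, c the example weight, and ||x||_w the weighted example norm), the numerator and denominator in the code correspond to the step

\[
\Delta\alpha \;=\; \frac{(y - \alpha - w\cdot x)\,c}{1 + \tfrac{1}{2}\,\lVert x\rVert_w^2\,c^2}
\]

i.e. delta_numerator / delta_denominator in the snippet.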
diff --git a/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc
index fb3c8154dd..ff2bae8fea 100644
--- a/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc
+++ b/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc
@@ -23,8 +23,8 @@ REGISTER_OP("SdcaSolver")
.Attr("num_sparse_features: int >= 0")
.Attr("num_dense_features: int >= 0")
.Attr("l1: float >= 0")
- .Attr("l2: float >= 0")
- .Attr("duality_gap_threshold: float = 0.01")
+ .Attr("l2: float >= 1")
+ .Attr("num_inner_iterations: int >= 2")
.Attr("container: string")
.Attr("solver_uuid: string")
.Input("sparse_features_indices: num_sparse_features * int64")
@@ -35,7 +35,6 @@ REGISTER_OP("SdcaSolver")
.Input("example_ids: string")
.Input("sparse_weights: Ref(num_sparse_features * float)")
.Input("dense_weights: Ref(num_dense_features * float)")
- .Input("primal_loss: Ref(double)")
.Doc(R"doc(
Stochastic Dual Coordinate Ascent (SDCA) optimizer for linear models with
L1 + L2 regularization. As the global optimization objective is strongly-convex, the
@@ -54,7 +53,7 @@ num_sparse_features: Number of sparse feature groups to train on.
num_dense_features: Number of dense feature groups to train on.
l1: Symmetric l1 regularization strength.
l2: Symmetric l2 regularization strength.
-duality_gap_threshold: Gap threshold at which we should stop training.
+num_inner_iterations: Number of iterations per mini-batch.
container: Name of the Container that stores data across invocations of this
Kernel. Together with SolverUUID form an isolation unit for this solver.
solver_uuid: Universally Unique Identifier for this solver.
@@ -75,4 +74,53 @@ dense_weights: a list of vectors where the value is the weight associated with
a dense feature group.
)doc");
+REGISTER_OP("SdcaShrinkL1")
+ .Attr("num_sparse_features: int >= 0")
+ .Attr("num_dense_features: int >= 0")
+ .Attr("l1: float >= 0")
+ .Attr("l2: float >= 1")
+ .Input("sparse_weights: Ref(num_sparse_features * float)")
+ .Input("dense_weights: Ref(num_dense_features * float)")
+ .Doc(R"doc(
+Applies L1 regularization shrink step on the parameters.
+
+num_sparse_features: Number of sparse feature groups to train on.
+num_dense_features: Number of dense feature groups to train on.
+l1: Symmetric l1 regularization strength.
+l2: Symmetric l2 regularization strength.
+sparse_weights: a list of vectors where each value is the weight associated with
+ a feature index.
+dense_weights: a list of vectors where the value is the weight associated with
+ a dense feature group.
+)doc");
+
+// TODO(katsiapis): We should expand the scope of this op to compute other
+// statistics about the data.
+REGISTER_OP("ComputeDualityGap")
+ .Attr("num_sparse_features: int >= 0")
+ .Attr("num_dense_features: int >= 0")
+ .Attr("l1: float >= 0")
+ .Attr("l2: float >= 1")
+ .Attr("container: string")
+ .Attr("solver_uuid: string")
+ .Input("sparse_weights: Ref(num_sparse_features * float)")
+ .Input("dense_weights: Ref(num_dense_features * float)")
+ .Output("duality_gap: float")
+ .Doc(R"doc(
+Computes duality gap over all examples seen by the optimizer.
+
+num_sparse_features: Number of sparse feature groups to train on.
+num_dense_features: Number of dense feature groups to train on.
+l1: Symmetric l1 regularization strength.
+l2: Symmetric l2 regularization strength.
+container: Name of the Container that stores data across invocations of this
+ Kernel. Together with SolverUUID form an isolation unit for this solver.
+solver_uuid: Universally Unique Identifier for this solver.
+sparse_weights: a list of vectors where each value is the weight associated with
+ a feature index.
+dense_weights: a list of vectors where the value is the weight associated with
+ a dense feature group.
+duality_gap: duality gap over all examples seen by the optimizer.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 71994fd51e..13968457f7 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -89,11 +89,8 @@ def make_variable_dict(max_age, max_gender):
# examples_dict.
age_weights = tf.Variable(tf.zeros([max_age + 1], dtype=tf.float32))
gender_weights = tf.Variable(tf.zeros([max_gender + 1], dtype=tf.float32))
- primal_loss = tf.Variable(tf.zeros([], dtype=tf.float64))
return dict(sparse_features_weights=[age_weights, gender_weights],
- dense_features_weights=[],
- primal_loss=primal_loss)
-
+ dense_features_weights=[])
def make_dense_variable_dict(num_dense_features, num_examples):
feature_weights = ([
@@ -121,6 +118,7 @@ def get_binary_predictions_for_hinge(predictions):
all_ones = tf.ones_like(predictions)
return tf.add(tf.sign(predictions), all_ones) / 2
+
# Setup the single container shared across all tests. This is testing proper
# isolation across optimizers instantiated in each of the tests below.
CONTAINER = uuid.uuid4().hex
@@ -155,22 +153,32 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='logistic_loss',
- prior=0.0)
- tf.initialize_all_variables().run()
+ loss_type='logistic_loss')
+
lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
predictions = lr.predictions(examples)
self.assertAllClose(0.693147, unregularized_loss.eval())
self.assertAllClose(0.693147, loss.eval())
- lr.minimize().run()
- self.assertAllClose(0.395226, unregularized_loss.eval(),
- rtol=3e-2, atol=3e-2)
- self.assertAllClose(0.657446, loss.eval(),
- rtol=3e-2, atol=3e-2)
+ for _ in xrange(5):
+ lr.minimize().run()
+ # The high tolerance in unregularized_loss comparisons is due to the
+ # fact that it's possible to trade off unregularized_loss vs.
+ # regularization and still have a sum that is quite close to the
+ # optimal regularized_loss value. SDCA's duality gap only ensures that
+ # the regularized_loss is within 0.01 of optimal.
+ # 0.525457 is the optimal regularized_loss.
+ # 0.411608 is the unregularized_loss at that optimum.
+ self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.11)
+ self.assertAllClose(0.525457, loss.eval(), atol=0.01)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 1], predicted_labels.eval())
+ self.assertAllClose(0.01,
+ lr.approximate_duality_gap().eval(),
+ rtol=1e-2,
+ atol=1e-2)
def testSomeUnweightedExamples(self):
# Setup test data with 4 examples, but should produce the same
@@ -201,18 +209,22 @@ class SdcaOptimizerTest(TensorFlowTestCase):
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
- tf.initialize_all_variables().run()
+
lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
predictions = lr.predictions(examples)
- lr.minimize().run()
- self.assertAllClose(0.395226, unregularized_loss.eval(),
- rtol=3e-2, atol=3e-2)
- self.assertAllClose(0.657446, loss.eval(),
- rtol=3e-2, atol=3e-2)
+ for _ in xrange(5):
+ lr.minimize().run()
+ self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.12)
+ self.assertAllClose(0.525457, loss.eval(), atol=0.01)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllClose([0, 1, 1, 1], predicted_labels.eval())
+ self.assertAllClose(0.01,
+ lr.approximate_duality_gap().eval(),
+ rtol=1e-2,
+ atol=1e-2)
def testFractionalLogisticExample(self):
# Setup test data with 1 positive, and 1 mostly-negative example.
@@ -231,10 +243,12 @@ class SdcaOptimizerTest(TensorFlowTestCase):
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
+
+ lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
with self.assertRaisesOpError(
'Only labels of 0.0 or 1.0 are supported right now.'):
- SdcaModel(CONTAINER, examples, variables, options).minimize().run()
+ lr.minimize().run()
def testNoWeightedExamples(self):
# Setup test data with 1 positive, and 1 negative example.
@@ -254,8 +268,9 @@ class SdcaOptimizerTest(TensorFlowTestCase):
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
- tf.initialize_all_variables().run()
+
lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
with self.assertRaisesOpError(
'No weighted examples in 2 training examples'):
@@ -281,8 +296,9 @@ class SdcaOptimizerTest(TensorFlowTestCase):
options = dict(symmetric_l2_regularization=0.5,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
- tf.initialize_all_variables().run()
+
lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
with self.assertRaisesOpError('Detected 1 duplicates in example_ids'):
lr.minimize().run()
@@ -310,19 +326,25 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(3, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='logistic_loss',
- prior=-1.09861)
- tf.initialize_all_variables().run()
+ loss_type='logistic_loss')
+
lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
predictions = lr.predictions(examples)
- lr.minimize().run()
- self.assertAllClose(0.331710, unregularized_loss.eval(),
- rtol=3e-2, atol=3e-2)
- self.assertAllClose(0.591295, loss.eval(), rtol=3e-2, atol=3e-2)
+ for _ in xrange(5):
+ lr.minimize().run()
+ self.assertAllClose(0.226487 + 0.102902,
+ unregularized_loss.eval(),
+ rtol=0.08)
+ self.assertAllClose(0.328394 + 0.131364, loss.eval(), atol=0.01)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 0, 0, 1], predicted_labels.eval())
+ self.assertAllClose(0.01,
+ lr.approximate_duality_gap().eval(),
+ rtol=1e-2,
+ atol=1e-2)
def testImbalancedWithExampleWeights(self):
# Setup test data with 1 positive, and 1 negative example.
@@ -341,17 +363,22 @@ class SdcaOptimizerTest(TensorFlowTestCase):
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
- tf.initialize_all_variables().run()
+
lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
predictions = lr.predictions(examples)
- lr.minimize().run()
- self.assertAllClose(0.266189, unregularized_loss.eval(),
- rtol=3e-2, atol=3e-2)
- self.assertAllClose(0.571912, loss.eval(), rtol=3e-2, atol=3e-2)
+ for _ in xrange(5):
+ lr.minimize().run()
+ self.assertAllClose(0.284860, unregularized_loss.eval(), rtol=0.08)
+ self.assertAllClose(0.408044, loss.eval(), atol=0.012)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 1], predicted_labels.eval())
+ self.assertAllClose(0.01,
+ lr.approximate_duality_gap().eval(),
+ rtol=1e-2,
+ atol=1e-2)
def testInstancesOfOneClassOnly(self):
# Setup test data with 1 positive (ignored), and 1 negative example.
@@ -367,24 +394,25 @@ class SdcaOptimizerTest(TensorFlowTestCase):
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(1, 1)
- options = dict(symmetric_l2_regularization=0.25,
+ options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
- tf.initialize_all_variables().run()
+
lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
- prediction = lr.predictions(examples)
- lr.minimize().run()
- self.assertAllClose(0.395226,
- unregularized_loss.eval(),
- rtol=3e-2,
- atol=3e-2)
- self.assertAllClose(0.460781, loss.eval(), rtol=3e-2, atol=3e-2)
- predicted_labels = tf.cast(
- tf.greater_equal(prediction,
- tf.ones_like(prediction) * 0.5), tf.float32)
+ predictions = lr.predictions(examples)
+ for _ in xrange(5):
+ lr.minimize().run()
+ self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.12)
+ self.assertAllClose(0.525457, loss.eval(), atol=0.01)
+ predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 0], predicted_labels.eval())
+ self.assertAllClose(0.01,
+ lr.approximate_duality_gap().eval(),
+ rtol=1e-2,
+ atol=1e-2)
def testSimpleLinear(self):
# Setup test data
@@ -402,21 +430,26 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='squared_loss',
- prior=0.0)
- tf.initialize_all_variables().run()
+ loss_type='squared_loss')
+
lr = SdcaModel(CONTAINER, examples, variables, options)
- prediction = lr.predictions(examples)
+ tf.initialize_all_variables().run()
+ predictions = lr.predictions(examples)
- lr.minimize().run()
+ for _ in xrange(20):
+ lr.minimize().run()
# Predictions should be 2/3 of label due to minimizing regularized loss:
# (label - 2 * weight)^2 / 2 + L2 * 2 * weight^2
self.assertAllClose([-20.0 / 3.0, 28.0 / 3.0],
- prediction.eval(),
+ predictions.eval(),
rtol=0.005)
+ self.assertAllClose(0.01,
+ lr.approximate_duality_gap().eval(),
+ rtol=1e-2,
+ atol=1e-2)
- def testLinearRegularization(self):
+ def testLinearL2Regularization(self):
# Setup test data
example_protos = [
# 2 identical examples
@@ -440,13 +473,14 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=16,
symmetric_l1_regularization=0,
- loss_type='squared_loss',
- prior=0.0)
- tf.initialize_all_variables().run()
+ loss_type='squared_loss')
+
lr = SdcaModel(CONTAINER, examples, variables, options)
- prediction = lr.predictions(examples)
+ tf.initialize_all_variables().run()
+ predictions = lr.predictions(examples)
- lr.minimize().run()
+ for _ in xrange(5):
+ lr.minimize().run()
# Predictions should be 1/5 of label due to minimizing regularized loss:
# (label - 2 * weight)^2 + L2 * 16 * weight^2
@@ -454,9 +488,42 @@ class SdcaOptimizerTest(TensorFlowTestCase):
optimal2 = 14.0 / 5.0
self.assertAllClose(
[optimal1, optimal1, optimal2, optimal2],
- prediction.eval(),
+ predictions.eval(),
rtol=0.01)
+ def testLinearL1Regularization(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto(
+ {'age': [0],
+ 'gender': [0]}, -10.0),
+ make_example_proto(
+ {'age': [1],
+ 'gender': [1]}, 14.0),
+ ]
+ example_weights = [1.0, 1.0]
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1)
+ options = dict(symmetric_l2_regularization=1.0,
+ symmetric_l1_regularization=4.0,
+ loss_type='squared_loss')
+ lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
+ prediction = lr.predictions(examples)
+ loss = lr.regularized_loss(examples)
+
+ for _ in xrange(5):
+ lr.minimize().run()
+
+      # Predictions should be -4.0, 20/3 due to minimizing, per example,
+      #   (label - 2 * weight)^2 / 4 + L2 * weight^2 / 2 + L1 * weight
+ self.assertAllClose([-4.0, 20.0 / 3.0], prediction.eval(), rtol=0.08)
+
+ # Loss should be the sum of the regularized loss value from above per
+ # example after plugging in the optimal weights.
+ self.assertAllClose(308.0 / 6.0, loss.eval(), atol=0.01)
+
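As a sanity check on the optima quoted in the test above, the asserted predictions can be reproduced by brute-force minimization of the per-example objective written in the comment; this is an editorial sketch, not part of the test:

```python
# Minimize (label - 2*w)**2 / 4 + L2 * w**2 / 2 + L1 * |w| over the shared
# per-feature weight w (two identical features, so the prediction is 2*w),
# with L2 = 1.0 and L1 = 4.0 as in the options above.
def best_prediction(label, l1=4.0, l2=1.0):
  candidates = [i / 1000.0 for i in range(-10000, 10001)]
  w = min(candidates,
          key=lambda w: (label - 2 * w)**2 / 4.0 + l2 * w**2 / 2.0 + l1 * abs(w))
  return 2 * w

assert abs(best_prediction(-10.0) - (-4.0)) < 1e-2
assert abs(best_prediction(14.0) - 20.0 / 3.0) < 1e-2
```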
def testLinearFeatureValues(self):
# Setup test data
example_protos = [
@@ -474,18 +541,19 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='squared_loss',
- prior=0.0)
- tf.initialize_all_variables().run()
+ loss_type='squared_loss')
+
lr = SdcaModel(CONTAINER, examples, variables, options)
- prediction = lr.predictions(examples)
+ tf.initialize_all_variables().run()
+ predictions = lr.predictions(examples)
- lr.minimize().run()
+ for _ in xrange(20):
+ lr.minimize().run()
# Predictions should be 8/9 of label due to minimizing regularized loss:
# (label - 2 * 2 * weight)^2 / 2 + L2 * 2 * weight^2
self.assertAllClose([-10.0 * 8 / 9, 14.0 * 8 / 9],
- prediction.eval(),
+ predictions.eval(),
rtol=0.07)
def testLinearDenseFeatures(self):
@@ -497,25 +565,22 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_dense_variable_dict(2, 2)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='squared_loss',
- prior=0.0)
- tf.initialize_all_variables().run()
+ loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
- prediction = lr.predictions(examples)
+ tf.initialize_all_variables().run()
+ predictions = lr.predictions(examples)
- lr.minimize().run()
+ for _ in xrange(20):
+ lr.minimize().run()
# Predictions should be 4/5 of label due to minimizing regularized loss:
# (label - 2 * weight)^2 / 2 + L2 * weight^2
self.assertAllClose([-10.0 * 4 / 5, 14.0 * 4 / 5],
- prediction.eval(),
+ predictions.eval(),
rtol=0.01)
loss = lr.regularized_loss(examples)
- self.assertAllClose(
- (4.0 + 7.84 + 16.0 + 31.36) / 2,
- loss.eval(),
- rtol=0.01)
+ self.assertAllClose(148.0 / 10.0, loss.eval(), atol=0.01)
def testSimpleHinge(self):
# Setup test data
@@ -533,10 +598,9 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1.0,
symmetric_l1_regularization=0,
- loss_type='hinge_loss',
- prior=0.0)
- tf.initialize_all_variables().run()
+ loss_type='hinge_loss')
model = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
# Before minimization, the weights default to zero. There is no loss due
# to regularization, only unregularized loss which is 0.5 * (1+1) = 1.0.
@@ -551,13 +615,15 @@ class SdcaOptimizerTest(TensorFlowTestCase):
# are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender (say w3
# and w4). Solving the system w1 + w3 = 1.0, w2 + w4 = -1.0 and minimizing
# wrt to \|\vec{w}\|_2, gives w1=w3=1/2 and w2=w4=-1/2. This gives 0.0
- # unregularized loss and 0.5 L2 loss.
- model.minimize().run()
+ # unregularized loss and 0.25 L2 loss.
+ for _ in xrange(5):
+ model.minimize().run()
+
binary_predictions = get_binary_predictions_for_hinge(predictions)
self.assertAllEqual([-1.0, 1.0], predictions.eval())
self.assertAllEqual([0.0, 1.0], binary_predictions.eval())
self.assertAllClose(0.0, unregularized_loss.eval())
- self.assertAllClose(0.5, regularized_loss.eval(), atol=0.05)
+ self.assertAllClose(0.25, regularized_loss.eval(), atol=0.05)
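The 0.25 figure above is consistent with the regularization term being averaged over the total example weight (2.0 in this test); that normalization is inferred from the assertions in these tests rather than stated in this change:

```python
# Four weights of magnitude 0.5 give sum(w^2) = 4 * 0.25 = 1.0, so with
# l2 = 1.0 the reported L2 cost is 0.5 * 1.0 * 1.0 / 2.0 = 0.25.
l2 = 1.0
sum_w_sq = 4 * 0.5**2
total_example_weight = 2.0
l2_cost = 0.5 * l2 * sum_w_sq / total_example_weight
assert abs(l2_cost - 0.25) < 1e-6
```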
def testHingeDenseFeaturesPerfectlySeparable(self):
with self._single_threaded_test_session():
@@ -569,22 +635,25 @@ class SdcaOptimizerTest(TensorFlowTestCase):
options = dict(symmetric_l2_regularization=1.0,
symmetric_l1_regularization=0,
loss_type='hinge_loss')
- tf.initialize_all_variables().run()
model = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
predictions = model.predictions(examples)
binary_predictions = get_binary_predictions_for_hinge(predictions)
- model.minimize().run()
+
+ for _ in xrange(5):
+ model.minimize().run()
+
self.assertAllClose([1.0, -1.0], predictions.eval(), atol=0.05)
self.assertAllClose([1.0, 0.0], binary_predictions.eval())
# (1.0, 1.0) and (1.0, -1.0) are perfectly separable by x-axis (that is,
# the SVM's functional margin >=1), so the unregularized loss is ~0.0.
# There is only loss due to l2-regularization. For these datapoints, it
- # turns out that w_1~=0.0 and w_2~=1.0 which means that l2 loss is ~0.5.
+ # turns out that w_1~=0.0 and w_2~=1.0 which means that l2 loss is ~0.25.
unregularized_loss = model.unregularized_loss(examples)
regularized_loss = model.regularized_loss(examples)
self.assertAllClose(0.0, unregularized_loss.eval(), atol=0.02)
- self.assertAllClose(0.5, regularized_loss.eval(), atol=0.02)
+ self.assertAllClose(0.25, regularized_loss.eval(), atol=0.02)
def testHingeDenseFeaturesSeparableWithinMargins(self):
with self._single_threaded_test_session():
@@ -596,22 +665,24 @@ class SdcaOptimizerTest(TensorFlowTestCase):
options = dict(symmetric_l2_regularization=1.0,
symmetric_l1_regularization=0,
loss_type='hinge_loss')
- tf.initialize_all_variables().run()
model = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
predictions = model.predictions(examples)
binary_predictions = get_binary_predictions_for_hinge(predictions)
- model.minimize().run()
+
+ for _ in xrange(5):
+ model.minimize().run()
# (1.0, 0.5) and (1.0, -0.5) are separable by x-axis but the datapoints
# are within the margins so there is unregularized loss (1/2 per example).
# For these datapoints, optimal weights are w_1~=0.0 and w_2~=1.0 which
- # gives an L2 loss of ~0.5.
+ # gives an L2 loss of ~0.25.
self.assertAllClose([0.5, -0.5], predictions.eval(), rtol=0.05)
self.assertAllClose([1.0, 0.0], binary_predictions.eval())
unregularized_loss = model.unregularized_loss(examples)
regularized_loss = model.regularized_loss(examples)
self.assertAllClose(0.5, unregularized_loss.eval(), atol=0.02)
- self.assertAllClose(1.0, regularized_loss.eval(), atol=0.02)
+ self.assertAllClose(0.75, regularized_loss.eval(), atol=0.02)
def testHingeDenseFeaturesWeightedExamples(self):
with self._single_threaded_test_session():
@@ -623,24 +694,26 @@ class SdcaOptimizerTest(TensorFlowTestCase):
options = dict(symmetric_l2_regularization=1.0,
symmetric_l1_regularization=0,
loss_type='hinge_loss')
- tf.initialize_all_variables().run()
model = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
predictions = model.predictions(examples)
binary_predictions = get_binary_predictions_for_hinge(predictions)
- model.minimize().run()
+ for _ in xrange(5):
+ model.minimize().run()
# Point (1.0, 0.5) has higher weight than (1.0, -0.5) so the model will
# try to increase the margin from (1.0, 0.5). Due to regularization,
# (1.0, -0.5) will be within the margin. For these points and example
# weights, the optimal weights are w_1~=0.4 and w_2~=1.2 which give an L2
- # loss of 0.25 * 1.6 = 0.4. The binary predictions will be correct, but
- # the boundary will be much closer to the 2nd point than the first one.
+      # loss of 0.5 * 0.25 * 1.6 = 0.2. The binary predictions will be
+      # correct, but the boundary will be much closer to the 2nd point than the
+      # first one.
self.assertAllClose([1.0, -0.2], predictions.eval(), atol=0.05)
self.assertAllClose([1.0, 0.0], binary_predictions.eval(), atol=0.05)
unregularized_loss = model.unregularized_loss(examples)
regularized_loss = model.regularized_loss(examples)
self.assertAllClose(0.2, unregularized_loss.eval(), atol=0.02)
- self.assertAllClose(0.6, regularized_loss.eval(), atol=0.02)
+ self.assertAllClose(0.4, regularized_loss.eval(), atol=0.02)
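A quick check of the arithmetic above, assuming the example weights in this test sum to 4.0 (inferred from the assertions; the weights themselves are outside this hunk):

```python
# With w_1 ~= 0.4 and w_2 ~= 1.2, sum(w^2) = 0.16 + 1.44 = 1.6, so the
# reported L2 cost is 0.5 * 1.0 * 1.6 / 4.0 = 0.2, matching the 0.4
# regularized loss (0.2 unregularized + 0.2 regularization).
l2_cost = 0.5 * 1.0 * (0.4**2 + 1.2**2) / 4.0
assert abs(l2_cost - 0.2) < 1e-6
```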
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index e986d40338..957a734b07 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -20,11 +20,16 @@ from __future__ import print_function
import os.path
import uuid
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework.load_library import load_op_library
from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables as var_ops
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
from tensorflow.python.platform import resource_loader
@@ -99,11 +104,12 @@ class SdcaModel(object):
the model, by resetting its (possibly shared) container.
```python
- # Execute opt_op once to perform training, which continues until
- convergence.
- The op makes use of duality gap as a certificate for termination. Duality
- gap is set to 0.01 as default.
- opt_op.run()
+ # Execute opt_op and train for num_steps.
+ for _ in xrange(num_steps):
+ opt_op.run()
+
+ # You can also check for convergence by calling
+ # lr.approximate_duality_gap()
```
"""
@@ -125,8 +131,7 @@ class SdcaModel(object):
self._assertList(['sparse_features', 'dense_features'], examples)
self._assertSpecified(
- ['sparse_features_weights', 'dense_features_weights',
- 'primal_loss'], variables)
+ ['sparse_features_weights', 'dense_features_weights'], variables)
self._assertList(
['sparse_features_weights', 'dense_features_weights'], variables)
@@ -138,9 +143,26 @@ class SdcaModel(object):
self._examples = examples
self._variables = variables
self._options = options
- self._primal_loss = convert_to_tensor(self._variables['primal_loss'],
- as_ref=True)
self._solver_uuid = uuid.uuid4().hex
+ self._create_slots(variables)
+
+ # TODO(rohananil): Use optimizer interface to make use of slot creation
+ # logic
+ def _create_slots(self, variables):
+ self._slots = {}
+ # TODO(rohananil): Rename the slot keys to "unshrinked" weights.
+ self._slots['sparse_features_weights'] = []
+ self._slots['dense_features_weights'] = []
+ self._assign_ops = []
+    # Make internal variables which hold the updated weights before the L1
+    # shrinkage (regularization) step is applied.
+ for var_type in ['sparse_features_weights', 'dense_features_weights']:
+ for var in variables[var_type]:
+ if var is not None:
+ self._slots[var_type].append(var_ops.Variable(array_ops.zeros_like(
+ var.initialized_value(), dtypes.float32)))
+ self._assign_ops.append(state_ops.assign(var, self._slots[var_type][
+ -1]))
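The slots above hold the pre-shrinkage ("unshrinked") weights; the visible variables are then passed through `sdca_shrink_l1`. The exact update performed by that op is not shown in this diff, but L1 shrinkage is conventionally a soft-thresholding step, sketched here with NumPy purely for intuition:

```python
import numpy as np

def soft_threshold(weights, threshold):
  # Classic L1 proximal (soft-thresholding) step: shrink each weight toward
  # zero by `threshold` and clamp the ones that cross zero.
  return np.sign(weights) * np.maximum(np.abs(weights) - threshold, 0.0)
```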
def _assertSpecified(self, items, check_in):
for x in items:
@@ -160,7 +182,7 @@ class SdcaModel(object):
dense_weights = self._convert_n_to_tensor(self._variables[
'dense_features_weights'])
l1 = self._options['symmetric_l1_regularization']
- loss = 0
+ loss = 0.0
for w in sparse_weights:
loss += l1 * math_ops.reduce_sum(abs(w))
for w in dense_weights:
@@ -175,12 +197,13 @@ class SdcaModel(object):
dense_weights = self._convert_n_to_tensor(self._variables[
'dense_features_weights'])
l2 = self._options['symmetric_l2_regularization']
- loss = 0
+ loss = 0.0
for w in sparse_weights:
loss += l2 * math_ops.reduce_sum(math_ops.square(w))
for w in dense_weights:
loss += l2 * math_ops.reduce_sum(math_ops.square(w))
- return loss
+ # SDCA L2 regularization cost is 1/2 * l2 * sum(weights^2)
+ return loss / 2.0
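A worked instance of the convention noted in the comment above (values are made up for illustration):

```python
# With l2 = 2.0 and weights [1.0, 3.0], the L2 regularization cost is
# 0.5 * 2.0 * (1.0 + 9.0) = 10.0.
l2 = 2.0
weights = [1.0, 3.0]
l2_cost = 0.5 * l2 * sum(w * w for w in weights)
assert l2_cost == 10.0
```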
def _convert_n_to_tensor(self, input_list, as_ref=False):
"""Converts input list to a set of tensors."""
@@ -247,23 +270,52 @@ class SdcaModel(object):
sparse_features_indices.append(convert_to_tensor(sf.indices))
sparse_features_weights.append(convert_to_tensor(sf.values))
- return _sdca_ops.sdca_solver(
+ step_op = _sdca_ops.sdca_solver(
sparse_features_indices,
sparse_features_weights,
self._convert_n_to_tensor(self._examples['dense_features']),
convert_to_tensor(self._examples['example_weights']),
convert_to_tensor(self._examples['example_labels']),
convert_to_tensor(self._examples['example_ids']),
- self._convert_n_to_tensor(self._variables['sparse_features_weights'],
+ self._convert_n_to_tensor(self._slots['sparse_features_weights'],
as_ref=True),
- self._convert_n_to_tensor(self._variables['dense_features_weights'],
+ self._convert_n_to_tensor(self._slots['dense_features_weights'],
as_ref=True),
- self._primal_loss,
l1=self._options['symmetric_l1_regularization'],
l2=self._options['symmetric_l2_regularization'],
+ num_inner_iterations=2,
loss_type=self._options['loss_type'],
container=self._container,
solver_uuid=self._solver_uuid)
+ with ops.control_dependencies([step_op]):
+ assign_ops = control_flow_ops.group(*self._assign_ops)
+ with ops.control_dependencies([assign_ops]):
+ return _sdca_ops.sdca_shrink_l1(
+ self._convert_n_to_tensor(
+ self._variables['sparse_features_weights'],
+ as_ref=True),
+ self._convert_n_to_tensor(
+ self._variables['dense_features_weights'],
+ as_ref=True),
+ l1=self._options['symmetric_l1_regularization'],
+ l2=self._options['symmetric_l2_regularization'])
+
+ def approximate_duality_gap(self):
+ """Add operations to compute the approximate duality gap.
+
+ Returns:
+ An Operation that computes the approximate duality gap over all
+ examples.
+ """
+ return _sdca_ops.compute_duality_gap(
+ self._convert_n_to_tensor(self._slots['sparse_features_weights'],
+ as_ref=True),
+ self._convert_n_to_tensor(self._slots['dense_features_weights'],
+ as_ref=True),
+ l1=self._options['symmetric_l1_regularization'],
+ l2=self._options['symmetric_l2_regularization'],
+ container=self._container,
+ solver_uuid=self._solver_uuid)
def unregularized_loss(self, examples):
"""Add operations to compute the loss (without the regularization loss).
@@ -310,8 +362,9 @@ class SdcaModel(object):
err = math_ops.sub(labels, predictions)
weighted_squared_err = math_ops.mul(math_ops.square(err), weights)
+ # SDCA squared loss function is sum(err^2) / (2*sum(weights))
return (math_ops.reduce_sum(weighted_squared_err) /
- math_ops.reduce_sum(weights))
+ (2.0 * math_ops.reduce_sum(weights)))
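A small numeric instance of the normalization in the return statement above (illustrative values):

```python
# Two examples with errors [1.0, 3.0] and example weights [1.0, 1.0]:
# sum(err^2 * weight) = 10.0, and the loss is 10.0 / (2.0 * 2.0) = 2.5.
errors = [1.0, 3.0]
weights = [1.0, 1.0]
loss = sum(e * e * w for e, w in zip(errors, weights)) / (2.0 * sum(weights))
assert loss == 2.5
```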
def regularized_loss(self, examples):
"""Add operations to compute the loss with regularization loss included.
@@ -321,7 +374,7 @@ class SdcaModel(object):
Returns:
An Operation that computes mean (regularized) loss for given set of
- examples.
+ examples.
Raises:
ValueError: if examples are not well defined.
"""
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index dc9d398a1b..5ee2337647 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -36,10 +36,6 @@
#
# filegroup ":android_proto_srcs" - Protos
# filegroup ":android_srcs" - Core sources
-# filegroup ":android_core_ops" - Essential kernels
-# filegroup ":android_extended_ops" - Optional kernels
-# filegroup ":android_extended_ops_group1" - Optional kernels, first batch
-# filegroup ":android_extended_ops_group2" - Optional kernels, second batch
# cc_library ":android_tensorflow_lib" - Native library
# portable_proto_library ":android_proto_lib" (Google-internal)
@@ -117,7 +113,6 @@ tf_proto_library_cc(
srcs = ["protobuf/master.proto"],
cc_api_version = 2,
cc_libs = [":protos_all_cc"],
- py_api_version = 2,
visibility = [
"//tensorflow:internal",
],
@@ -131,7 +126,6 @@ tf_proto_library_cc(
cc_grpc_version = 1,
cc_libs = [":master_proto_cc"],
cc_stubby_versions = ["2"],
- py_api_version = 2,
visibility = [
"//tensorflow:internal",
],
@@ -611,57 +605,6 @@ filegroup(
visibility = ["//visibility:public"],
)
-# Core kernels we want on Android. Only a subset of kernels to keep
-# base library small.
-filegroup(
- name = "android_core_ops",
- srcs = [
- "//tensorflow/core/kernels:android_core_ops",
- ],
- visibility = ["//visibility:public"],
-)
-
-# Other kernels we may want on Android.
-#
-# The kernels can be consumed as a whole or in two groups for
-# supporting separate compilation. Note that the split into groups
-# is entirely for improving compilation time, and not for
-# organizational reasons; you should not depend on any
-# of those groups independently.
-filegroup(
- name = "android_extended_ops",
- srcs = [
- ":android_extended_ops_group1",
- ":android_extended_ops_group2",
- ],
- visibility = ["//visibility:public"],
-)
-
-filegroup(
- name = "android_extended_ops_headers",
- srcs = [
- "//tensorflow/core/kernels:android_extended_ops_headers",
- ],
-)
-
-filegroup(
- name = "android_extended_ops_group1",
- srcs = [
- ":android_extended_ops_headers",
- "//tensorflow/core/kernels:android_extended_ops_group1",
- ],
- visibility = ["//visibility:public"],
-)
-
-filegroup(
- name = "android_extended_ops_group2",
- srcs = [
- ":android_extended_ops_headers",
- "//tensorflow/core/kernels:android_extended_ops_group2",
- ],
- visibility = ["//visibility:public"],
-)
-
# Config setting for determining if we are building for Android.
config_setting(
name = "android",
@@ -718,8 +661,8 @@ cc_library(
cc_library(
name = "android_tensorflow_lib",
srcs = [
- "//tensorflow/core:android_core_ops",
- "//tensorflow/core:android_extended_ops",
+ "//tensorflow/core/kernels:android_core_ops",
+ "//tensorflow/core/kernels:android_extended_ops",
],
copts = select({
":android": ANDROID_TF_COPTS,
@@ -738,6 +681,7 @@ cc_library(
],
)
+# -----------------------------------------------------------------------------
# Libraries for GPU facilities that are useful for writing kernels.
cc_library(
name = "gpu_lib",
@@ -963,10 +907,6 @@ tf_cuda_library(
],
),
copts = tf_copts(),
- cuda_deps = [
- ":core_gpu_internal",
- ":stream_executor",
- ],
deps = [
":framework",
":framework_internal",
@@ -982,39 +922,6 @@ tf_cuda_library(
alwayslink = 1,
)
-# This target should not link in any GPU runtime (CUDA) dependencies,
-# only libraries for interfacing with GPUs that can be safely linked
-# into CPU binaries.
-cc_library(
- name = "core_gpu_internal",
- srcs = [
- "common_runtime/gpu/gpu_allocator_retry.cc",
- "common_runtime/gpu/gpu_bfc_allocator.cc",
- "common_runtime/gpu/gpu_debug_allocator.cc",
- "common_runtime/gpu/gpu_init.cc",
- "common_runtime/gpu/pool_allocator.cc",
- "common_runtime/gpu/process_state.cc",
- ],
- hdrs = [
- "common_runtime/gpu/gpu_allocator_retry.h",
- "common_runtime/gpu/gpu_bfc_allocator.h",
- "common_runtime/gpu/gpu_debug_allocator.h",
- "common_runtime/gpu/gpu_init.h",
- "common_runtime/gpu/pool_allocator.h",
- "common_runtime/gpu/process_state.h",
- "common_runtime/gpu/visitable_allocator.h",
- ],
- copts = tf_copts(),
- deps = [
- ":framework",
- ":framework_internal",
- ":lib",
- ":lib_internal",
- ":protos_all_cc",
- ":stream_executor",
- ],
-)
-
cc_library(
name = "cuda",
deps = [
@@ -1044,16 +951,29 @@ tf_cuda_library(
tf_cuda_library(
name = "gpu_runtime",
srcs = [
+ "common_runtime/gpu/gpu_allocator_retry.cc",
+ "common_runtime/gpu/gpu_bfc_allocator.cc",
+ "common_runtime/gpu/gpu_debug_allocator.cc",
"common_runtime/gpu/gpu_device.cc",
"common_runtime/gpu/gpu_device_factory.cc",
+ "common_runtime/gpu/gpu_init.cc",
"common_runtime/gpu/gpu_stream_util.cc",
"common_runtime/gpu/gpu_util.cc",
"common_runtime/gpu/gpu_util_platform_specific.cc",
+ "common_runtime/gpu/pool_allocator.cc",
+ "common_runtime/gpu/process_state.cc",
],
hdrs = [
+ "common_runtime/gpu/gpu_allocator_retry.h",
+ "common_runtime/gpu/gpu_bfc_allocator.h",
+ "common_runtime/gpu/gpu_debug_allocator.h",
"common_runtime/gpu/gpu_device.h",
+ "common_runtime/gpu/gpu_init.h",
"common_runtime/gpu/gpu_stream_util.h",
"common_runtime/gpu/gpu_util.h",
+ "common_runtime/gpu/pool_allocator.h",
+ "common_runtime/gpu/process_state.h",
+ "common_runtime/gpu/visitable_allocator.h",
],
copts = tf_copts(),
cuda_deps = [
@@ -1063,7 +983,6 @@ tf_cuda_library(
deps = [
":core_cpu",
":core_cpu_internal",
- ":core_gpu_internal",
":framework",
":framework_internal",
":gpu_lib",
@@ -1247,7 +1166,6 @@ tf_cc_tests(
":all_kernels",
":core_cpu",
":core_cpu_internal",
- ":core_gpu_internal",
":direct_session",
":framework",
":framework_internal",
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
index d37a55784d..d0726f235c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/threadpool_device.h"
namespace tensorflow {
@@ -61,6 +62,49 @@ class GPUDeviceFactory : public BaseGPUDeviceFactory {
REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory);
+//------------------------------------------------------------------------------
+// A CPUDevice that optimizes for interaction with GPUs in the
+// process.
+// -----------------------------------------------------------------------------
+class GPUCompatibleCPUDevice : public ThreadPoolDevice {
+ public:
+ GPUCompatibleCPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency,
+ Allocator* allocator)
+ : ThreadPoolDevice(options, name, memory_limit, bus_adjacency,
+ allocator) {}
+ ~GPUCompatibleCPUDevice() override {}
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override {
+ ProcessState* ps = ProcessState::singleton();
+ if (attr.gpu_compatible()) {
+ return ps->GetCUDAHostAllocator(0);
+ } else {
+ // Call the parent's implementation.
+ return ThreadPoolDevice::GetAllocator(attr);
+ }
+ }
+};
+
+// The associated factory.
+class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
+ public:
+ void CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override {
+ int n = 1;
+ auto iter = options.config.device_count().find("CPU");
+ if (iter != options.config.device_count().end()) {
+ n = iter->second;
+ }
+ for (int i = 0; i < n; i++) {
+ string name = strings::StrCat(name_prefix, "/cpu:", i);
+ devices->push_back(new GPUCompatibleCPUDevice(
+ options, name, Bytes(256 << 20), BUS_ANY, cpu_allocator()));
+ }
+ }
+};
+REGISTER_LOCAL_DEVICE_FACTORY("CPU", GPUCompatibleCPUDeviceFactory, 50);
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 98f42a7e45..6477e9a336 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -26,10 +26,6 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
-#if GOOGLE_CUDA
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
-#endif // GOOGLE_CUDA
-
namespace tensorflow {
ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
@@ -56,12 +52,6 @@ void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
}
Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) {
-#if GOOGLE_CUDA
- ProcessState* ps = ProcessState::singleton();
- if (attr.gpu_compatible()) {
- return ps->GetCUDAHostAllocator(0);
- }
-#endif // GOOGLE_CUDA
return allocator_;
}
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 00d97a6ef9..fb672196ca 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -269,6 +269,18 @@ cc_library(
],
)
+cc_library(
+ name = "server_lib",
+ srcs = ["server_lib.cc"],
+ hdrs = ["server_lib.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
# TODO(mrry): Move executor_test.cc to ../common_runtime when once it no longer depends
# on grpc_testlib.
tf_cc_tests(
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index d9d016e67e..df86046c45 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -211,7 +211,7 @@ cc_library(
srcs = [
"grpc_server_lib.cc",
],
- hdrs = ["grpc_server_lib.h"],
+ linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
deps = [
"@grpc//:grpc++_unsecure",
":async_service_interface",
@@ -230,8 +230,10 @@ cc_library(
"//tensorflow/core/distributed_runtime:master_env",
"//tensorflow/core/distributed_runtime:master_session",
"//tensorflow/core/distributed_runtime:process_util",
+ "//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
],
+ alwayslink = 1,
)
cc_binary(
@@ -247,6 +249,7 @@ cc_binary(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/distributed_runtime:server_lib",
],
)
@@ -276,6 +279,7 @@ cc_binary(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:server_lib",
],
)
@@ -344,5 +348,6 @@ tf_cc_tests(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime:process_util",
+ "//tensorflow/core/distributed_runtime:server_lib",
],
)
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 5cefc7605f..629441cee4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
-
#include <memory>
#include "grpc++/grpc++.h"
@@ -33,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -41,14 +40,14 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
-
namespace {
-class TensorFlowServer : public ServerInterface {
+
+class GrpcServer : public ServerInterface {
public:
- TensorFlowServer(const ServerDef& server_def, Env* env)
+ GrpcServer(const ServerDef& server_def, Env* env)
: server_def_(server_def), env_(env), state_(NEW) {}
- ~TensorFlowServer() {
+ ~GrpcServer() {
Stop();
Join();
@@ -59,8 +58,14 @@ class TensorFlowServer : public ServerInterface {
// to destroy them.
delete master_env_.worker_cache; // Shared with worker_env.worker_cache.
- delete worker_env_.device_mgr;
+ // We must delete graph_mgr before device_mgr, due to shared
+ // ownership of OpKernels in the executors. (The graph_mgr will
+ // free all stateless OpKernels, and pass over borrowed stateful
+ // OpKernels, which are also held in their respective devices'
+ // OpSegments.)
delete worker_env_.graph_mgr;
+ delete worker_env_.device_mgr;
+
delete worker_env_.rendezvous_mgr;
// Do not delete (as these are not owned by the server):
@@ -91,6 +96,56 @@ class TensorFlowServer : public ServerInterface {
return errors::Internal("Could not parse worker name.");
}
+ // Look up the port that has been requested for this task in `server_def_`.
+ requested_port_ = -1;
+ for (const auto& job : server_def_.cluster().job()) {
+ if (job.name() == server_def_.job_name()) {
+ auto iter = job.tasks().find(server_def_.task_index());
+ if (iter == job.tasks().end()) {
+ return errors::InvalidArgument("Task ", server_def_.task_index(),
+ " was not defined in job \"",
+ server_def_.job_name(), "\"");
+ } else if (!str_util::NumericParse32(
+ str_util::Split(iter->second, ':')[1],
+ &requested_port_)) {
+ return errors::Internal(
+ "Could not parse port for local server from \"", iter->second,
+ "\"");
+ } else {
+ break;
+ }
+ }
+ }
+ if (requested_port_ == -1) {
+ return errors::Internal("Job \"", server_def_.job_name(),
+ "\" was not defined in cluster");
+ }
+
+ // N.B. The order of initialization here is intricate, because we
+ // wish to allow `requested_port_ == 0` (for choosing any port,
+ // mostly for testing). Therefore, the construction of the channel
+ // and worker caches depends on `bound_port_`, which is not set
+ // until we call `builder.BuildAndStart()`. We must create the
+ // service objects before calling `builder.BuildAndStart()`, but
+ // `master_env_` and `worker_env_` are only partially
+ // configured. However, this is not dangerous, because we do not
+ // start serving requests until `this->Start()` is called, which
+ // happens after this method returns.
+ //
+ // TODO(mrry): Provide a general mechanism for dynamically setting
+ // the identities of tasks in the worker pool after the service is
+ // running.
+ ::grpc::ServerBuilder builder;
+ builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
+ ::grpc::InsecureServerCredentials(), &bound_port_);
+ master_service_ = NewGrpcMasterService(&master_env_, &builder);
+ worker_service_ = NewGrpcWorkerService(&worker_env_, &builder);
+ server_ = builder.BuildAndStart();
+
+ if (!server_) {
+ return errors::Internal("Could not start gRPC server");
+ }
+
GrpcChannelSpec channel_spec;
for (const auto& job : server_def_.cluster().job()) {
int max_task_id = -1;
@@ -99,7 +154,12 @@ class TensorFlowServer : public ServerInterface {
}
std::vector<string> host_ports(max_task_id + 1);
for (const auto& task : job.tasks()) {
- host_ports[task.first] = task.second;
+ if (job.name() == server_def_.job_name() &&
+ task.first == server_def_.task_index()) {
+ host_ports[task.first] = strings::StrCat("localhost:", bound_port_);
+ } else {
+ host_ports[task.first] = task.second;
+ }
}
channel_spec.AddHostPortsJob(job.name(), host_ports, host_ports.size());
}
@@ -133,12 +193,6 @@ class TensorFlowServer : public ServerInterface {
mutex_lock l(mu_);
switch (state_) {
case NEW: {
- ::grpc::ServerBuilder builder;
- builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
- ::grpc::InsecureServerCredentials());
- master_service_ = NewGrpcMasterService(&master_env_, &builder);
- worker_service_ = NewGrpcWorkerService(&worker_env_, &builder);
- server_ = builder.BuildAndStart();
master_thread_.reset(
env_->StartThread(ThreadOptions(), "TF_master_service",
[this] { master_service_->HandleRPCsLoop(); }));
@@ -196,7 +250,9 @@ class TensorFlowServer : public ServerInterface {
}
}
- const string& target() const override { return target_; }
+ const string target() const override {
+ return strings::StrCat("grpc://localhost:", bound_port_);
+ }
private:
// The overall server configuration.
@@ -204,8 +260,9 @@ class TensorFlowServer : public ServerInterface {
Env* env_;
// The port requested for this server.
- // TODO(mrry): Support requested_port_ == 0 to bind to any available port.
int requested_port_;
+ // The port to which this server is bound.
+ int bound_port_ = 0;
// The `SessionOptions.target` to be used when connecting to this
// server (as a master).
@@ -238,15 +295,30 @@ class TensorFlowServer : public ServerInterface {
std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);
};
-} // namespace
-Status NewServer(const ServerDef& server_def,
- std::unique_ptr<ServerInterface>* out_server) {
- std::unique_ptr<TensorFlowServer> ret(
- new TensorFlowServer(server_def, Env::Default()));
- TF_RETURN_IF_ERROR(ret->Init());
- *out_server = std::move(ret);
- return Status::OK();
-}
+class GrpcServerFactory : public ServerFactory {
+ public:
+ bool AcceptsOptions(const ServerDef& server_def) override {
+ return server_def.protocol() == "grpc";
+ }
+ Status NewServer(const ServerDef& server_def,
+ std::unique_ptr<ServerInterface>* out_server) override {
+ std::unique_ptr<GrpcServer> ret(new GrpcServer(server_def, Env::Default()));
+ TF_RETURN_IF_ERROR(ret->Init());
+ *out_server = std::move(ret);
+ return Status::OK();
+ }
+};
+
+// Registers a `ServerFactory` for `GrpcServer` instances.
+class GrpcServerRegistrar {
+ public:
+ GrpcServerRegistrar() {
+ ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
+ }
+};
+static GrpcServerRegistrar registrar;
+
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
deleted file mode 100644
index a06989d88c..0000000000
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/* Copyright 2016 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
-
-#include <memory>
-
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
-
-namespace tensorflow {
-
-// Represents a single TensorFlow server, which exports Master and Worker
-// services.
-class ServerInterface {
- public:
- ServerInterface() {}
- virtual ~ServerInterface() {}
-
- // Starts the server running asynchronously. Returns OK on success, otherwise
- // returns an error.
- virtual Status Start() = 0;
-
- // Stops the server asynchronously. Returns OK on success, otherwise returns
- // an error.
- //
- // After calling `Stop()`, the caller may call `Join()` to block until the
- // server has stopped.
- virtual Status Stop() = 0;
-
- // Blocks until the server has stopped. Returns OK on success, otherwise
- // returns an error.
- virtual Status Join() = 0;
-
- // Returns a target string that can be used to connect to this server using
- // `tensorflow::NewSession()`.
- virtual const string& target() const = 0;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(ServerInterface);
-};
-
-// Creates a server based on the given `server_def`, and stores it in
-// *out_server. Returns OK on success, otherwise returns an error.
-Status NewServer(const ServerDef& server_def,
- std::unique_ptr<ServerInterface>* out_server);
-
-} // namespace tensorflow
-
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc
index 902519769c..a56afb05a6 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,6 +25,7 @@ namespace tensorflow {
// when no calls are made against the server.
TEST(Server, StopAfterNoop) {
ServerDef def;
+ def.set_protocol("grpc");
def.set_job_name("localhost");
def.set_task_index(0);
JobDef* job_def = def.mutable_cluster()->add_job();
@@ -42,6 +43,7 @@ TEST(Server, StopAfterNoop) {
// when a simple call is made against the server.
TEST(Server, StopAfterCall) {
ServerDef def;
+ def.set_protocol("grpc");
def.set_job_name("localhost");
def.set_task_index(0);
JobDef* job_def = def.mutable_cluster()->add_job();
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
index 4b4aa0a2f9..27651b4770 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "grpc++/security/credentials.h"
#include "grpc++/server_builder.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -31,10 +31,13 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"
// This binary starts a TensorFlow server (master and worker).
+//
+// TODO(mrry): Replace with a py_binary that uses `tf.GrpcServer()`.
namespace tensorflow {
namespace {
Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
+ options->set_protocol("grpc");
string cluster_spec;
int task_index = 0;
const bool parse_result = ParseFlags(
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
index 700ae1f373..a563f124c4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "grpc++/security/credentials.h"
#include "grpc++/server_builder.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -33,6 +33,7 @@ namespace tensorflow {
namespace {
Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
+ options->set_protocol("grpc");
string job_spec;
int num_cpus = 1;
int num_gpus = 0;
diff --git a/tensorflow/core/distributed_runtime/server_lib.cc b/tensorflow/core/distributed_runtime/server_lib.cc
new file mode 100644
index 0000000000..45d4f70a3f
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/server_lib.cc
@@ -0,0 +1,73 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+namespace {
+mutex* get_server_factory_lock() {
+ static mutex server_factory_lock;
+ return &server_factory_lock;
+}
+
+typedef std::unordered_map<string, ServerFactory*> ServerFactories;
+ServerFactories* server_factories() {
+ static ServerFactories* factories = new ServerFactories;
+ return factories;
+}
+} // namespace
+
+/* static */
+void ServerFactory::Register(const string& server_type,
+ ServerFactory* factory) {
+ mutex_lock l(*get_server_factory_lock());
+ if (!server_factories()->insert({server_type, factory}).second) {
+ LOG(ERROR) << "Two server factories are being registered under "
+ << server_type;
+ }
+}
+
+/* static */
+Status ServerFactory::GetFactory(const ServerDef& server_def,
+ ServerFactory** out_factory) {
+ mutex_lock l(*get_server_factory_lock());
+ // TODO(mrry): Improve the error reporting here.
+ for (const auto& server_factory : *server_factories()) {
+ if (server_factory.second->AcceptsOptions(server_def)) {
+ *out_factory = server_factory.second;
+ return Status::OK();
+ }
+ }
+ return errors::NotFound(
+ "No server factory registered for the given ServerDef: ",
+ server_def.DebugString());
+}
+
+// Creates a server based on the given `server_def`, and stores it in
+// `*out_server`. Returns OK on success, otherwise returns an error.
+Status NewServer(const ServerDef& server_def,
+ std::unique_ptr<ServerInterface>* out_server) {
+ ServerFactory* factory;
+ TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory));
+ return factory->NewServer(server_def, out_server);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/server_lib.h b/tensorflow/core/distributed_runtime/server_lib.h
new file mode 100644
index 0000000000..dea682795a
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/server_lib.h
@@ -0,0 +1,98 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
+
+#include <memory>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
+
+namespace tensorflow {
+
+// This library supports a registration/factory-based mechanism for
+// creating TensorFlow server objects. Each server implementation must
+// have an accompanying implementation of ServerFactory, and create a
+// static "registrar" object that calls `ServerFactory::Register()`
+// with an instance of the factory class. See "rpc/grpc_server_lib.cc"
+// for an example.
+
+// Represents a single TensorFlow server that exports Master and Worker
+// services.
+class ServerInterface {
+ public:
+ ServerInterface() {}
+ virtual ~ServerInterface() {}
+
+ // Starts the server running asynchronously. Returns OK on success, otherwise
+ // returns an error.
+ virtual Status Start() = 0;
+
+ // Stops the server asynchronously. Returns OK on success, otherwise returns
+ // an error.
+ //
+ // After calling `Stop()`, the caller may call `Join()` to block until the
+ // server has stopped.
+ virtual Status Stop() = 0;
+
+ // Blocks until the server has stopped. Returns OK on success, otherwise
+ // returns an error.
+ virtual Status Join() = 0;
+
+ // Returns a target string that can be used to connect to this server using
+ // `tensorflow::NewSession()`.
+ virtual const string target() const = 0;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ServerInterface);
+};
+
+class ServerFactory {
+ public:
+ // Creates a new server based on the given `server_def`, and stores
+ // it in `*out_server`. Returns OK on success, otherwise returns an
+ // error.
+ virtual Status NewServer(const ServerDef& server_def,
+ std::unique_ptr<ServerInterface>* out_server) = 0;
+
+ // Returns true if and only if this factory can create a server
+ // based on the given `server_def`.
+ virtual bool AcceptsOptions(const ServerDef& server_def) = 0;
+
+ virtual ~ServerFactory() {}
+
+ // For each `ServerFactory` subclass, an instance of that class must
+ // be registered by calling this method.
+ //
+ // The `server_type` must be unique to the server factory.
+ static void Register(const string& server_type, ServerFactory* factory);
+
+ // Looks up a factory that can create a server based on the given
+ // `server_def`, and stores it in `*out_factory`. Returns OK on
+ // success, otherwise returns an error.
+ static Status GetFactory(const ServerDef& server_def,
+ ServerFactory** out_factory);
+};
+
+// Creates a server based on the given `server_def`, and stores it in
+// `*out_server`. Returns OK on success, otherwise returns an error.
+Status NewServer(const ServerDef& server_def,
+ std::unique_ptr<ServerInterface>* out_server);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index f3aecf0b96..09a1aa1a17 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1234,7 +1234,7 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
}
#define OP_REQUIRES(CTX, EXP, STATUS) \
- if (!(EXP)) { \
+ if (!TF_PREDICT_TRUE(EXP)) { \
(CTX)->CtxFailure((STATUS)); \
return; \
}
@@ -1242,14 +1242,14 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
#define OP_REQUIRES_OK(CTX, STATUS) \
do { \
::tensorflow::Status _s(STATUS); \
- if (!_s.ok()) { \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
(CTX)->CtxFailureWithWarning(_s); \
return; \
} \
} while (0)
#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
- if (!(EXP)) { \
+ if (!TF_PREDICT_TRUE(EXP)) { \
(CTX)->CtxFailure((STATUS)); \
(CALLBACK)(); \
return; \
@@ -1258,7 +1258,7 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \
do { \
::tensorflow::Status _s(STATUS); \
- if (!_s.ok()) { \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
(CTX)->CtxFailureWithWarning(_s); \
(CALLBACK)(); \
return; \
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a7c7551ea6..7ca186eaec 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -28,16 +28,6 @@ cc_library(
],
)
-cc_library(
- name = "bounds_check",
- hdrs = ["bounds_check.h"],
- visibility = ["//visibility:private"],
- deps = [
- "//tensorflow/core:framework",
- "//third_party/eigen3",
- ],
-)
-
tf_kernel_library(
name = "concat_lib",
srcs = ["concat_lib_cpu.cc"],
@@ -153,7 +143,6 @@ tf_proto_library(
cc_api_version = 2,
go_api_version = 2,
java_api_version = 2,
- py_api_version = 2,
)
cc_library(
@@ -200,6 +189,18 @@ cc_library(
],
)
+# Private support libraries ---------------------------------------------------
+
+cc_library(
+ name = "bounds_check",
+ hdrs = ["bounds_check.h"],
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
# OpKernel libraries ----------------------------------------------------------
tf_kernel_libraries(
@@ -652,6 +653,7 @@ tf_kernel_libraries(
"sparse_matmul_op",
],
deps = [
+ ":bounds_check",
":fill_functor",
":transpose_functor",
"//tensorflow/core:core_cpu",
@@ -734,6 +736,7 @@ tf_kernel_libraries(
"xent_op",
],
deps = [
+ ":bounds_check",
":conv_2d",
":conv_ops",
":depthwise_conv_op",
@@ -980,6 +983,8 @@ filegroup(
],
)
+# Core kernels we want on Android. Only a subset of kernels to keep
+# base library small.
filegroup(
name = "android_core_ops",
srcs = [
@@ -1036,6 +1041,22 @@ filegroup(
],
)
+# Other kernels we may want on Android.
+#
+# The kernels can be consumed as a whole or in two groups for
+# supporting separate compilation. Note that the split into groups
+# is entirely for improving compilation time, and not for
+# organizational reasons; you should not depend on any
+# of those groups independently.
+filegroup(
+ name = "android_extended_ops",
+ srcs = [
+ ":android_extended_ops_group1",
+ ":android_extended_ops_group2",
+ ],
+ visibility = ["//visibility:public"],
+)
+
filegroup(
name = "android_extended_ops_headers",
srcs = [
@@ -1090,6 +1111,7 @@ filegroup(
"cwise_op_sub.cc",
"cwise_op_tanh.cc",
"dynamic_partition_op.cc",
+ ":android_extended_ops_headers",
],
)
@@ -1122,6 +1144,7 @@ filegroup(
"transpose_op.cc",
"where_op.cc",
"xent_op.cc",
+ ":android_extended_ops_headers",
],
)
diff --git a/tensorflow/core/kernels/bounds_check.h b/tensorflow/core/kernels/bounds_check.h
index 665cbdaff9..9bfbde9bc7 100644
--- a/tensorflow/core/kernels/bounds_check.h
+++ b/tensorflow/core/kernels/bounds_check.h
@@ -33,6 +33,19 @@ EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) {
static_cast<UIndex>(limit));
}
+// Upcasting specializations when the index and bounds do not match;
+// always move to the larger type.
+
+EIGEN_ALWAYS_INLINE bool FastBoundsCheck(int64 index, int32 limit) {
+ return TF_PREDICT_TRUE(static_cast<uint64>(index) <
+ static_cast<uint64>(limit));
+}
+
+EIGEN_ALWAYS_INLINE bool FastBoundsCheck(int32 index, int64 limit) {
+ return TF_PREDICT_TRUE(static_cast<uint64>(index) <
+ static_cast<uint64>(limit));
+}
+
namespace internal {
// Ensure that the compiler cannot elide a copy into a local, for
// bounds checking on source tensors that might be updated asynchronously.
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index c48dd309c7..18ee40e623 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -214,9 +214,10 @@ class DecodeCSVOp : public OpKernel {
}
OP_REQUIRES(
- ctx, input[current_idx] == '"' &&
- (static_cast<size_t>(current_idx) == input.size() - 1 ||
- input[current_idx + 1] == delim_),
+ ctx, (static_cast<size_t>(current_idx) < input.size() &&
+ input[current_idx] == '"' &&
+ (static_cast<size_t>(current_idx) == input.size() - 1 ||
+ input[current_idx + 1] == delim_)),
errors::InvalidArgument("Quoted field has to end with quote "
"followed by delim or end"));
diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc
index 0172031e43..47f334ba76 100644
--- a/tensorflow/core/kernels/in_topk_op.cc
+++ b/tensorflow/core/kernels/in_topk_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
namespace tensorflow {
@@ -55,7 +56,10 @@ class InTopK : public OpKernel {
const auto size = targets.size();
const auto num_classes = predictions.dimension(1);
for (int b = 0; b < size; b++) {
- T target_prediction = predictions(b, targets(b));
+ auto target = internal::SubtleMustCopy(targets(b));
+ OP_REQUIRES(context, FastBoundsCheck(target, num_classes),
+ errors::InvalidArgument("targets[", b, "] is out of range"));
+ T target_prediction = predictions(b, target);
bool cannot_say = !std::isfinite(target_prediction);
int more_probable_classes = 0;
if (!cannot_say) {
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 8b672960d3..a4e75a7796 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"
@@ -78,7 +79,9 @@ class SegmentReductionOp : public OpKernel {
// Note that the current implementation assumes that segment_vec values are
// sorted.
const Index output_rows =
- num_indices > 0 ? segment_vec(num_indices - 1) + 1 : 0;
+ num_indices > 0
+ ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
+ : 0;
TensorShape output_shape = input.shape();
output_shape.set_dim(0, output_rows);
@@ -118,7 +121,14 @@ class SegmentReductionOp : public OpKernel {
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
Eigen::Unaligned>
OutT;
- T* out_slice_ptr = &output_flat(segment_vec(start), 0);
+
+ Index out_index = internal::SubtleMustCopy(segment_vec(start));
+ OP_REQUIRES(
+ context, FastBoundsCheck(out_index, output_rows),
+ errors::InvalidArgument(
+ "Segment id ", out_index, " out of range [0, ", output_rows,
+ "), probably because 'segment_ids' input is not sorted."));
+ T* out_slice_ptr = &output_flat(out_index, 0);
OutT out_slice(out_slice_ptr, out_slice_shape);
// We don't use out_slice.device(context->eigen_device<Device>)
// because these pieces of work are likely to be very small and
@@ -208,7 +218,6 @@ class UnsortedSegmentSumOp : public OpKernel {
context, IsLegacyScalar(num_segments.shape()),
errors::InvalidArgument("num_segments should be a scalar, not shape ",
num_segments.shape().DebugString()));
-
OP_REQUIRES(
context,
TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()),
@@ -218,15 +227,11 @@ class UnsortedSegmentSumOp : public OpKernel {
const auto segment_flat = segment_ids.flat<Index>();
const int32 N = segment_flat.dimension(0);
- const int32 output_rows = num_segments.scalar<int32>()();
-
- for (int i = 0; i < N; i++) {
- int j = segment_flat(i);
- OP_REQUIRES(context, 0 <= j && j < output_rows,
- errors::InvalidArgument(
- "segment_ids", SliceDebugString(segment_ids.shape(), i),
- " = ", j, " is out of range [0, ", output_rows, ")"));
- }
+ const Index output_rows =
+ internal::SubtleMustCopy(num_segments.scalar<int32>()());
+ OP_REQUIRES(context, output_rows >= 0,
+ errors::InvalidArgument("Input num_segments == ", output_rows,
+ " must not be negative."));
TensorShape output_shape;
output_shape.AddDim(output_rows);
@@ -242,8 +247,12 @@ class UnsortedSegmentSumOp : public OpKernel {
if (data.NumElements() > 0) {
auto data_flat = data.shaped<T, 2>({N, data.NumElements() / N});
for (int i = 0; i < N; ++i) {
- output_flat.template chip<0>(segment_flat(i)) +=
- data_flat.template chip<0>(i);
+ Index j = internal::SubtleMustCopy(segment_flat(i));
+ OP_REQUIRES(context, FastBoundsCheck(j, output_rows),
+ errors::InvalidArgument(
+ "segment_ids", SliceDebugString(segment_ids.shape(), i),
+ " = ", j, " is out of range [0, ", output_rows, ")"));
+ output_flat.template chip<0>(j) += data_flat.template chip<0>(i);
}
}
}
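At the Python level, the effect of these checks is that out-of-range segment ids now raise `InvalidArgumentError` instead of reading or writing past the output buffer. A hedged usage sketch, assuming the standard `tf.segment_sum` / `tf.unsorted_segment_sum` wrappers (exact error text may differ):

```python
import tensorflow as tf

with tf.Session() as sess:
    data = tf.constant([1.0, 2.0, 3.0, 4.0])

    # segment_sum expects sorted, in-range segment ids.
    print(sess.run(tf.segment_sum(data, [0, 0, 1, 1])))  # => [3. 7.]

    # With the new bounds checks, an id >= num_segments (here 5 >= 2)
    # is rejected rather than written out of bounds.
    bad = tf.unsorted_segment_sum(data, [0, 1, 5, 1], num_segments=2)
    try:
        sess.run(bad)
    except tf.errors.InvalidArgumentError as e:
        print(e.message)
```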
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 2ba571bcdb..4bddcd7e98 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -181,47 +181,52 @@ class StackPushOp : public AsyncOpKernel {
// Push the tensor onto the stack. Swap the tensor to CPU if instructed.
const Tensor& tensor = ctx->input(1);
AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
- DeviceContext* device_ctxt = ctx->op_device_context();
- auto device = static_cast<tensorflow::Device*>(ctx->device());
- Allocator* allocator = device->GetAllocator(alloc_attrs);
- AllocatorStats stats;
- allocator->GetStats(&stats);
+ // For now, we use a simple heuristic for swapping: A GPU tensor is moved
+ // to CPU if the tensor has more than kCopyThreshold bytes and the GPU
+ // allocator says more than kOccupancy of the memory is in use.
static constexpr int kCopyThreshold = 2048;
static constexpr double kOccupancy = 0.7;
if (swap_memory_ && !alloc_attrs.on_host() &&
std::is_same<Device, GPUDevice>::value &&
- stats.bytes_in_use > (stats.bytes_limit * kOccupancy) &&
tensor.TotalBytes() > kCopyThreshold) {
- // Asynchronously copy the tensor from GPU to CPU memory.
- // TODO(yuanbyu): Swap the oldest tensor first.
- AllocatorAttributes host_alloc_attrs;
- host_alloc_attrs.set_gpu_compatible(true);
- host_alloc_attrs.set_on_host(true);
- Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs);
- Tensor* cpu_tensor =
- new Tensor(cpu_allocator, tensor.dtype(), tensor.shape());
- device_ctxt->CopyDeviceTensorToCPU(
- &tensor, "StackPush", device, cpu_tensor,
- [cpu_tensor, stack, ctx, done](const Status& s) {
- ctx->SetStatus(s);
- if (s.ok()) {
- AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
- ctx->SetStatus(stack->Push(
- {PersistentTensor(*cpu_tensor), alloc_attrs, true}));
- }
- if (ctx->status().ok()) {
- ctx->set_output(0, *cpu_tensor);
- }
- done();
- delete cpu_tensor;
- });
- } else {
- // Execute synchronously if not swapped.
- OP_REQUIRES_OK(
- ctx, stack->Push({PersistentTensor(tensor), alloc_attrs, false}));
- ctx->set_output(0, tensor);
- done();
+ DeviceContext* device_ctxt = ctx->op_device_context();
+ auto device = static_cast<tensorflow::Device*>(ctx->device());
+ Allocator* allocator = device->GetAllocator(alloc_attrs);
+ AllocatorStats stats;
+ allocator->GetStats(&stats);
+ if (stats.bytes_in_use > (stats.bytes_limit * kOccupancy)) {
+ // Asynchronously copy the tensor from GPU to CPU memory.
+ // TODO(yuanbyu): Swap the oldest tensor first.
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs);
+ Tensor* cpu_tensor =
+ new Tensor(cpu_allocator, tensor.dtype(), tensor.shape());
+ device_ctxt->CopyDeviceTensorToCPU(
+ &tensor, "StackPush", device, cpu_tensor,
+ [cpu_tensor, stack, ctx, done](const Status& s) {
+ ctx->SetStatus(s);
+ if (s.ok()) {
+ AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
+ ctx->SetStatus(stack->Push(
+ {PersistentTensor(*cpu_tensor), alloc_attrs, true}));
+ }
+ if (ctx->status().ok()) {
+ ctx->set_output(0, *cpu_tensor);
+ }
+ done();
+ delete cpu_tensor;
+ });
+ return;
+ }
}
+
+ // Execute synchronously if not swapped.
+ OP_REQUIRES_OK(ctx,
+ stack->Push({PersistentTensor(tensor), alloc_attrs, false}));
+ ctx->set_output(0, tensor);
+ done();
}
bool IsExpensive() override { return false; }
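The swap heuristic described in the comment boils down to two threshold tests. A plain-Python sketch of just that decision logic (not TensorFlow API; constants copied from the kernel):

```python
# StackPush swap heuristic: move a GPU tensor to host memory only if it
# is reasonably large and the GPU allocator is getting full.
K_COPY_THRESHOLD = 2048   # bytes
K_OCCUPANCY = 0.7         # fraction of the allocator's byte limit

def should_swap_to_cpu(tensor_bytes, bytes_in_use, bytes_limit,
                       swap_memory_enabled, tensor_on_host):
    if not swap_memory_enabled or tensor_on_host:
        return False
    if tensor_bytes <= K_COPY_THRESHOLD:
        return False
    return bytes_in_use > bytes_limit * K_OCCUPANCY

# A 1 MB tensor with the allocator 80% full would be swapped:
assert should_swap_to_cpu(1 << 20, 8 << 30, 10 << 30, True, False)
```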
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 88786ec774..5ecef9c6f9 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -169,7 +169,7 @@ Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
TransposeGpuOp);
-TF_CALL_NUMBER_TYPES(REGISTER);
+TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
#endif
diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h
index 5fc2d5d20d..dc8de09d2c 100644
--- a/tensorflow/core/lib/random/philox_random.h
+++ b/tensorflow/core/lib/random/philox_random.h
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
// Function qualifiers that need to work on both CPU and GPU.
-#ifdef __CUDA_ARCH__
+#if defined(__CUDACC__)
// For nvcc.
#define PHILOX_DEVICE_FUNC __host__ __device__
#define PHILOX_INLINE __inline__
diff --git a/tensorflow/core/platform/macros.h b/tensorflow/core/platform/macros.h
index 9cc08eca52..c7d5e63a1b 100644
--- a/tensorflow/core/platform/macros.h
+++ b/tensorflow/core/platform/macros.h
@@ -50,8 +50,8 @@ limitations under the License.
#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0))
#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
#else
-#define TF_PREDICT_FALSE(x) x
-#define TF_PREDICT_TRUE(x) x
+#define TF_PREDICT_FALSE(x) (x)
+#define TF_PREDICT_TRUE(x) (x)
#endif
// A macro to disallow the copy constructor and operator= functions
diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto
index 5b4ee3e85a..9b8ec1b5ed 100644
--- a/tensorflow/core/protobuf/tensorflow_server.proto
+++ b/tensorflow/core/protobuf/tensorflow_server.proto
@@ -105,4 +105,9 @@ message ServerDef {
// The default configuration for sessions that run on this server.
ConfigProto default_session_config = 4;
+
+ // The protocol to be used by this server.
+ //
+ // Acceptable values include: "grpc".
+ string protocol = 5;
}
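Together with the experimental `tf.GrpcServer` added below, the new `protocol` field lets a server be defined entirely from Python. A hedged sketch that mirrors the new `server_lib_test.py`; the job name and address are placeholders:

```python
import tensorflow as tf

# Single-task cluster definition; "local" and the port are illustrative.
server_def = tf.ServerDef(protocol="grpc")
job = server_def.cluster.job.add()
job.name = "local"
job.tasks[0] = "localhost:2222"
server_def.job_name = job.name
server_def.task_index = 0

server = tf.GrpcServer(server_def)   # start=True by default

with tf.Session(server.target) as sess:
    print(sess.run(tf.constant(42)))
```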
diff --git a/tensorflow/examples/udacity/README.md b/tensorflow/examples/udacity/README.md
index 8308a56766..a6d6f8742a 100644
--- a/tensorflow/examples/udacity/README.md
+++ b/tensorflow/examples/udacity/README.md
@@ -6,7 +6,7 @@ Course information can be found at https://www.udacity.com/course/deep-learning-
Running the Docker container from the Google Cloud repository
-------------------------------------------------------------
- docker run -p 8888:8888 -it --rm b.gcr.io/tensorflow-udacity/assignments
+ docker run -p 8888:8888 -it --rm b.gcr.io/tensorflow-udacity/assignments:0.3.0
Accessing the Notebooks
-----------------------
@@ -61,9 +61,9 @@ This will allow you to save work and have access to generated files on the host
Pushing a Google Cloud release
------------------------------
- V=0.2.0
+ V=0.3.0
docker tag $USER/assignments b.gcr.io/tensorflow-udacity/assignments:$V
- docker tag $USER/assignments b.gcr.io/tensorflow-udacity/assignments:latest
+ docker tag -f $USER/assignments b.gcr.io/tensorflow-udacity/assignments:latest
gcloud docker push b.gcr.io/tensorflow-udacity/assignments
History
@@ -71,3 +71,4 @@ History
* 0.1.0: Initial release.
* 0.2.0: Many fixes, including lower memory footprint and support for Python 3.
+* 0.3.0: Use 0.7.1 release.
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index 507165b0aa..aaba57b3b0 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -285,7 +285,7 @@ def main(argv=None): # pylint: disable=unused-argument
batch_data = train_data[offset:(offset + BATCH_SIZE), ...]
batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
# This dictionary maps the batch data (as a numpy array) to the
- # node in the graph is should be fed to.
+ # node in the graph it should be fed to.
feed_dict = {train_data_node: batch_data,
train_labels_node: batch_labels}
# Run the graph and fetch some of the nodes.
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7533317ac3..6fcb33d39b 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -821,6 +821,7 @@ py_library(
deps = [
":framework",
":ops",
+ ":server_lib",
":session",
":training_ops",
],
@@ -897,6 +898,7 @@ tf_py_wrap_cc(
srcs = ["tensorflow.i"],
swig_includes = [
"client/events_writer.i",
+ "client/server_lib.i",
"client/tf_session.i",
"framework/python_op_gen.i",
"lib/core/py_func.i",
@@ -915,6 +917,8 @@ tf_py_wrap_cc(
":py_record_writer_lib",
":python_op_gen",
":tf_session_helper",
+ "//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//util/python:python_headers",
],
@@ -940,6 +944,28 @@ py_library(
],
)
+py_library(
+ name = "server_lib",
+ srcs = ["client/server_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pywrap_tensorflow",
+ ],
+)
+
+py_test(
+ name = "server_lib_test",
+ srcs = ["client/server_lib_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":extra_py_tests_deps",
+ ":framework",
+ ":framework_test_lib",
+ ":server_lib",
+ ":session",
+ ],
+)
+
# Just used by tests.
tf_cuda_library(
name = "construction_fails_op",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index aab8ada371..b04912b107 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -111,6 +111,7 @@ __all__ = make_all(__name__,
# documentation, or remove.
__all__.extend([
'AttrValue',
+ 'ClusterDef',
'ConfigProto',
'Event',
'GPUOptions',
@@ -119,7 +120,9 @@ __all__.extend([
'GRAPH_DEF_VERSION_MIN_PRODUCER',
'GraphDef',
'GraphOptions',
+ 'GrpcServer',
'HistogramProto',
+ 'JobDef',
'LogMessage',
'NameAttrList',
'NodeDef',
@@ -127,6 +130,7 @@ __all__.extend([
'PaddingFIFOQueue',
'RunOptions',
'RunOutputs',
+ 'ServerDef',
'SessionLog',
'Summary',
'arg_max',
diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py
index b06b37b7d0..0ab8f9dce0 100644
--- a/tensorflow/python/client/client_lib.py
+++ b/tensorflow/python/client/client_lib.py
@@ -51,6 +51,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+# NOTE(mrry): Support for `tf.GrpcServer` is currently experimental.
+from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef
+from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef
+from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
+from tensorflow.python.client.server_lib import GrpcServer
+
from tensorflow.python.client.session import InteractiveSession
from tensorflow.python.client.session import Session
diff --git a/tensorflow/python/client/server_lib.i b/tensorflow/python/client/server_lib.i
new file mode 100644
index 0000000000..835f883ef4
--- /dev/null
+++ b/tensorflow/python/client/server_lib.i
@@ -0,0 +1,88 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%nothread tensorflow::ServerInterface::Join;
+
+%include "tensorflow/python/platform/base.i"
+
+//%newobject tensorflow::NewServer;
+
+%typemap(in) const ServerDef& (tensorflow::ServerDef temp) {
+ char* c_string;
+ Py_ssize_t py_size;
+ if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
+ // Python has raised an error (likely TypeError or UnicodeEncodeError).
+ SWIG_fail;
+ }
+
+ if (!temp.ParseFromString(string(c_string, py_size))) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "The ServerDef could not be parsed as a valid protocol buffer");
+ SWIG_fail;
+ }
+ $1 = &temp;
+}
+
+%typemap(in, numinputs=0)
+ std::unique_ptr<tensorflow::ServerInterface>* out_server (
+ std::unique_ptr<tensorflow::ServerInterface> temp) {
+ $1 = &temp;
+}
+
+%typemap(out) tensorflow::Status tensorflow::NewServer {
+ if (!$1.ok()) {
+ RaiseStatusNotOK($1, $descriptor(tensorflow::Status*));
+ SWIG_fail;
+ }
+}
+
+%typemap(argout) std::unique_ptr<tensorflow::ServerInterface>* out_server {
+ // TODO(mrry): Convert this to SWIG_POINTER_OWN when the issues with freeing
+ // a server are fixed.
+ $result = SWIG_NewPointerObj($1->release(),
+ $descriptor(tensorflow::ServerInterface*),
+ 0);
+}
+
+%feature("except") tensorflow::ServerInterface::Join {
+ // Let other threads run while we wait for the server to shut down.
+ Py_BEGIN_ALLOW_THREADS
+ $action
+ Py_END_ALLOW_THREADS
+}
+
+%{
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+
+using tensorflow::ServerDef;
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::ServerDef;
+%unignore tensorflow::ServerInterface;
+%unignore tensorflow::ServerInterface::~ServerInterface;
+%unignore tensorflow::ServerInterface::Start;
+%unignore tensorflow::ServerInterface::Stop;
+%unignore tensorflow::ServerInterface::Join;
+%unignore tensorflow::ServerInterface::target;
+
+%unignore tensorflow::NewServer;
+
+%include "tensorflow/core/distributed_runtime/server_lib.h"
+
+%unignoreall
diff --git a/tensorflow/python/client/server_lib.py b/tensorflow/python/client/server_lib.py
new file mode 100644
index 0000000000..38612edf15
--- /dev/null
+++ b/tensorflow/python/client/server_lib.py
@@ -0,0 +1,86 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A Python interface for creating TensorFlow servers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six # pylint: disable=unused-import
+
+from tensorflow.core.protobuf import tensorflow_server_pb2
+from tensorflow.python import pywrap_tensorflow
+
+
+class GrpcServer(object):
+ """An in-process TensorFlow server.
+
+ NOTE(mrry): This class is experimental and not yet suitable for use.
+ """
+
+ def __init__(self, server_def, start=True):
+ """Creates a new server with the given definition.
+
+ Args:
+ server_def: A `tf.ServerDef` protocol buffer, describing the server to
+ be created (and the cluster of which it is a member).
+ start: (Optional.) Boolean, indicating whether to start the server after
+ creating it. Defaults to `True`.
+ """
+ if not isinstance(server_def, tensorflow_server_pb2.ServerDef):
+ raise TypeError("server_def must be a tf.ServerDef")
+
+ self._server = pywrap_tensorflow.NewServer(server_def.SerializeToString())
+ if start:
+ self.start()
+
+ def start(self):
+ """Starts this server."""
+ self._server.Start()
+
+ def stop(self):
+ """Stops this server.
+
+ NOTE(mrry): This method is currently not implemented.
+ """
+ # TODO(mrry): Implement this.
+ raise NotImplementedError("GrpcServer.stop()")
+
+ def join(self):
+ """Blocks until the server has shut down.
+
+ NOTE(mrry): Since `GrpcServer.stop()` is not currently implemented, this
+ method blocks forever.
+ """
+ self._server.Join()
+
+ @property
+ def target(self):
+ """Returns the target for a `tf.Session` to connect to this server.
+
+ To create a
+ [`tf.Session`](../../api_docs/python/client.md#Session) that
+ connects to this server, use the following snippet:
+
+ ```python
+ server = tf.GrpcServer(...)
+ with tf.Session(server.target):
+ # ...
+ ```
+
+ Returns:
+ A string containing a session target for this server.
+ """
+ return self._server.target()
diff --git a/tensorflow/python/client/server_lib_test.py b/tensorflow/python/client/server_lib_test.py
new file mode 100644
index 0000000000..5705f363d5
--- /dev/null
+++ b/tensorflow/python/client/server_lib_test.py
@@ -0,0 +1,65 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.GrpcServer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class GrpcServerTest(tf.test.TestCase):
+
+ def _localServer(self):
+ server_def = tf.ServerDef(protocol="grpc")
+ job_def = server_def.cluster.job.add()
+ job_def.name = "local"
+ job_def.tasks[0] = "localhost:0"
+ server_def.job_name = job_def.name
+ server_def.task_index = 0
+ return server_def
+
+ def testRunStep(self):
+ server = tf.GrpcServer(self._localServer())
+ server.start()
+
+ with tf.Session(server.target) as sess:
+ c = tf.constant([[2, 1]])
+ d = tf.constant([[1], [2]])
+ e = tf.matmul(c, d)
+ print(sess.run(e))
+ # TODO(mrry): Add `server.stop()` and `server.join()` when these work.
+
+ def testMultipleSessions(self):
+ server = tf.GrpcServer(self._localServer())
+ server.start()
+
+ c = tf.constant([[2, 1]])
+ d = tf.constant([[1], [2]])
+ e = tf.matmul(c, d)
+
+ sess_1 = tf.Session(server.target)
+ sess_2 = tf.Session(server.target)
+
+ sess_1.run(e)
+ sess_2.run(e)
+
+ sess_1.close()
+ sess_2.close()
+ # TODO(mrry): Add `server.stop()` and `server.join()` when these work.
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 7a7c58b19f..7180f7d77c 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -127,8 +127,8 @@ _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
"SessionInterface", "BaseSession", "NameAttrList",
"AttrValue", "TensorArray", "OptimizerOptions",
"CollectionDef", "MetaGraphDef", "QueueRunnerDef",
- "SaverDef", "VariableDef", "TestCase",
- ]
+ "SaverDef", "VariableDef", "TestCase", "GrpcServer",
+ "ClusterDef", "JobDef", "ServerDef"]
def main(unused_argv):
if not FLAGS.out_dir:
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 95c7cfc2cf..dabc474f42 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2526,9 +2526,11 @@ class Graph(object):
return name
@contextlib.contextmanager
- def colocate_with(self, op):
+ def colocate_with(self, op, ignore_existing=False):
"""Returns a context manager that specifies an op to colocate with.
+ Note: this function is not for public use, only for internal libraries.
+
For example:
```python
@@ -2543,6 +2545,9 @@ class Graph(object):
Args:
op: The op to colocate all created ops with.
+ ignore_existing: If true, only applies colocation of this op within
+ the context, rather than applying all colocation properties
+ on the stack.
Raises:
ValueError: if op is None.
@@ -2569,6 +2574,10 @@ class Graph(object):
device_fn_tmp = self._device_function_stack
self._device_function_stack = []
+ if ignore_existing:
+ current_stack = self._colocation_stack
+ self._colocation_stack = []
+
self._colocation_stack.append(op)
try:
@@ -2578,6 +2587,10 @@ class Graph(object):
self._device_function_stack = device_fn_tmp
self._colocation_stack.pop()
+ # Reset the colocation stack if requested.
+ if ignore_existing:
+ self._colocation_stack = current_stack
+
@contextlib.contextmanager
def device(self, device_name_or_function):
"""Returns a context manager that specifies the default device to use.
@@ -3007,8 +3020,8 @@ def device(device_name_or_function):
return get_default_graph().device(device_name_or_function)
-def colocate_with(op):
- return get_default_graph().colocate_with(op)
+def colocate_with(op, ignore_existing=False):
+ return get_default_graph().colocate_with(op, ignore_existing)
def name_scope(name):
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index b5dbd3c6f6..cfc96a0cc8 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -1283,6 +1283,14 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
c = constant_op.constant(4.0)
self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups()))
+ def testColocationIgnoreStack(self):
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant(3.0, name="b")
+ with ops.colocate_with(a.op):
+ with ops.colocate_with(b.op, ignore_existing=True):
+ c = constant_op.constant(4.0)
+ self.assertEqual(set(["loc:@b"]), set(c.op.colocation_groups()))
+
def testColocateVariables(self):
a = variables.Variable([2.0], name="a")
with ops.colocate_with(a.op):
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 766b416f75..d93020e825 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -328,6 +328,19 @@ class ZerosLikeTest(tf.test.TestCase):
z = tf.zeros_like(d)
self.assertEqual(d.get_shape().as_list(), z.get_shape().as_list())
+ def testZerosLikeDtype(self):
+ # Make sure zeros_like works even for dtypes that cannot be cast to each other.
+ with self.test_session():
+ shape = (3, 5)
+ dtypes = np.float32, np.complex64
+ for in_type in dtypes:
+ x = np.arange(15).astype(in_type).reshape(*shape)
+ for out_type in dtypes:
+ y = tf.zeros_like(x, dtype=out_type).eval()
+ self.assertEqual(y.dtype, out_type)
+ self.assertEqual(y.shape, shape)
+ self.assertAllEqual(y, np.zeros(shape, dtype=out_type))
+
class OnesTest(tf.test.TestCase):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 07f11354b9..77da519fcc 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -898,7 +898,7 @@ class ControlFlowTest(tf.test.TestCase):
r = control_flow_ops.While(c, b, [n, v], parallel_iterations=1)
r = tf.gradients(r[1], x)[0]
- self.assertEqual(r.get_shape().as_list(), [None])
+ self.assertEqual(r.get_shape(), tensor_shape.unknown_shape())
self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
def testWhileGrad_MultipleUses(self):
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index 37541284d1..959268a544 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -160,6 +160,21 @@ class DecodeCSVOpTest(tf.test.TestCase):
args,
expected_err_re="Unquoted fields cannot have quotes/CRLFs inside")
+ def testWrongDefaults(self):
+ args = {
+ "records": [",1", "0.2,2", "3.0adf,3"],
+ "record_defaults": [[1.0]]
+ }
+
+ self._test(args,
+ expected_err_re="Expect 1 fields but have 2 in record 0")
+
+ def testShortQuotedString(self):
+ args = {"records": ["\""], "record_defaults": [["default"]],}
+
+ self._test(args,
+ expected_err_re="Quoted field has to end with quote followed.*")
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
index dd8a8350c8..97a064df9d 100644
--- a/tensorflow/python/kernel_tests/in_topk_op_test.py
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -58,6 +58,14 @@ class InTopKTest(tf.test.TestCase):
target = [0, 2]
self._validateInTopK(predictions, target, 2, [False, False])
+ def testBadTarget(self):
+ predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ target = [0, 80000]
+ with self.test_session():
+ with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+ "target.*out of range"):
+ tf.nn.in_top_k(predictions, target, 2).eval()
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index fe858b78b1..be59ac08c2 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -560,7 +560,7 @@ class LSTMTest(tf.test.TestCase):
for out0, out1 in zip(outputs0_values, outputs1_values):
self.assertAllEqual(out0, out1)
- def _testDynamicEquivalentToStaticRNN(self, use_gpu):
+ def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
time_steps = 8
num_units = 3
num_proj = 4
@@ -569,7 +569,10 @@ class LSTMTest(tf.test.TestCase):
input_values = np.random.randn(time_steps, batch_size, input_size)
- sequence_length = np.random.randint(0, time_steps, size=batch_size)
+ if use_sequence_length:
+ sequence_length = np.random.randint(0, time_steps, size=batch_size)
+ else:
+ sequence_length = None
########### Step 1: Run static graph and generate readouts
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
@@ -744,8 +747,14 @@ class LSTMTest(tf.test.TestCase):
self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=True)
def testDynamicEquivalentToStaticRNN(self):
- self._testDynamicEquivalentToStaticRNN(use_gpu=False)
- self._testDynamicEquivalentToStaticRNN(use_gpu=True)
+ self._testDynamicEquivalentToStaticRNN(
+ use_gpu=False, use_sequence_length=False)
+ self._testDynamicEquivalentToStaticRNN(
+ use_gpu=True, use_sequence_length=False)
+ self._testDynamicEquivalentToStaticRNN(
+ use_gpu=False, use_sequence_length=True)
+ self._testDynamicEquivalentToStaticRNN(
+ use_gpu=True, use_sequence_length=True)
class BidirectionalRNNTest(tf.test.TestCase):
@@ -1091,7 +1100,7 @@ def rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
def main(_):
print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
- for max_time in (1, 25, 50):
+ for max_time in (1, 25, 50, 100, 200):
graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
print("Calculation: Static Unroll with Dynamic Flow LSTM "
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 7257be5d59..44efdb538b 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -636,10 +636,12 @@ def zeros_like(tensor, dtype=None, name=None):
"""
with ops.op_scope([tensor], name, "zeros_like") as name:
tensor = ops.convert_to_tensor(tensor, name="tensor")
- ret = gen_array_ops._zeros_like(tensor)
- if (dtype is not None) and (tensor.dtype != dtype):
- ret = gen_math_ops.cast(ret, dtype)
- return ret
+ if dtype is not None and tensor.dtype != dtype:
+ ret = zeros(shape(tensor), dtype, name=name)
+ ret.set_shape(tensor.get_shape())
+ return ret
+ else:
+ return gen_array_ops._zeros_like(tensor, name=name)
def ones_like(tensor, dtype=None, name=None):
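The rewritten `zeros_like` builds zeros of the requested dtype directly from the input's shape rather than casting a same-dtype zeros tensor, which matters for dtype pairs with no cast kernel (the case the new `testZerosLikeDtype` above exercises). A short usage sketch:

```python
import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    x = tf.constant(np.arange(6).reshape(2, 3).astype(np.complex64))

    # Different dtype: zeros of the requested dtype are built directly,
    # so no complex64 -> float32 cast kernel is required.
    z = tf.zeros_like(x, dtype=tf.float32)

    # Same dtype: still lowered to the ZerosLike kernel.
    z_same = tf.zeros_like(x)

    print(sess.run(z).dtype, sess.run(z_same).dtype)  # float32 complex64
```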
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index cc911ca24f..aa85c12931 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -36,17 +36,18 @@ def _SwitchGrad(op, *grad):
the merge on the first visit, and update the other input of the merge
on the second visit. A next_iteration is also added on second visit.
"""
- real_op = GetRealOp(op)
+ graph = ops.get_default_graph()
# pylint: disable=protected-access
- ctxt = real_op._get_control_flow_context()
+ op_ctxt = op._get_control_flow_context()
+ grad_ctxt = graph._get_control_flow_context()
# pylint: enable=protected-access
- if isinstance(ctxt, WhileContext):
- merge_op = op.grad_state.switch_map.get(real_op)
+ if isinstance(op_ctxt, WhileContext):
+ merge_op = grad_ctxt.grad_state.switch_map.get(op)
if merge_op:
# This is the second time this Switch is visited. It comes from
# the non-exit branch of the Switch, so update the second input
# to the Merge.
- # TODO: Need to perform shape inference with this new input.
+ # TODO: Perform shape inference with this new input.
# pylint: disable=protected-access
merge_op._update_input(1, control_flow_ops._NextIteration(grad[1]))
# pylint: enable=protected-access
@@ -58,21 +59,22 @@ def _SwitchGrad(op, *grad):
# input of merge when we see this Switch the second time.
merge_fn = control_flow_ops._Merge # pylint: disable=protected-access
merge_op = merge_fn([grad[0], grad[0]], name="b_switch")[0]
- op.grad_state.switch_map[real_op] = merge_op.op
+ grad_ctxt.grad_state.switch_map[op] = merge_op.op
return merge_op, None
- elif isinstance(ctxt, CondContext):
- good_grad = grad[ctxt.branch]
- zero_grad = grad[1 - ctxt.branch]
- # If this Switch is wrapped, it is part of a cond within a loop. In
- # this case, we have called ControlFlowState.ZeroLike() so grad is
- # ready for merge. Otherwise, we need a switch to control zero_grad.
- if not isinstance(op, ControlFlowOpWrapper):
+ elif isinstance(op_ctxt, CondContext):
+ good_grad = grad[op_ctxt.branch]
+ zero_grad = grad[1 - op_ctxt.branch]
+ # If we are in a grad context, this switch is part of a cond within a
+ # loop. In this case, we have called ControlFlowState.ZeroLike() so grad
+ # is ready for merge. Otherwise, we need a switch to control zero_grad.
+ if not (grad_ctxt and grad_ctxt.grad_state):
dtype = good_grad.dtype
- zero_grad = switch(zero_grad, ctxt.pred, dtype=dtype)[1 - ctxt.branch]
+ branch = op_ctxt.branch
+ zero_grad = switch(zero_grad, op_ctxt.pred, dtype=dtype)[1 - branch]
return merge([good_grad, zero_grad], name="cond_grad")[0], None
else:
- false_grad = switch(grad[0], real_op.inputs[1])[0]
- true_grad = switch(grad[1], real_op.inputs[1])[1]
+ false_grad = switch(grad[0], op.inputs[1])[0]
+ true_grad = switch(grad[1], op.inputs[1])[1]
return merge([false_grad, true_grad])[0], None
@@ -83,24 +85,24 @@ ops.RegisterGradient("RefSwitch")(_SwitchGrad)
@ops.RegisterGradient("Merge")
def _MergeGrad(op, grad, _):
"""Gradients for a Merge op are calculated using a Switch op."""
- real_op = GetRealOp(op)
- input_op = real_op.inputs[0].op
+ input_op = op.inputs[0].op
+ graph = ops.get_default_graph()
# pylint: disable=protected-access
- ctxt = input_op._get_control_flow_context()
+ op_ctxt = input_op._get_control_flow_context()
+ grad_ctxt = graph._get_control_flow_context()
# pylint: enable=protected-access
- if isinstance(ctxt, WhileContext):
- grad_ctxt = op.grad_state.grad_context
+ if isinstance(op_ctxt, WhileContext):
# pylint: disable=protected-access
return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
# pylint: enable=protected-access
- elif isinstance(ctxt, CondContext):
- pred = ctxt.pred
- if isinstance(op, ControlFlowOpWrapper):
+ elif isinstance(op_ctxt, CondContext):
+ pred = op_ctxt.pred
+ if grad_ctxt and grad_ctxt.grad_state:
# This Merge node is part of a cond within a loop.
# The backprop needs to have the value of this predicate for every
# iteration. So we must have its values accumulated in the forward, and
# use the accumulated values as the predicate for this backprop switch.
- grad_state = op.grad_state
+ grad_state = grad_ctxt.grad_state
real_pred = grad_state.history_map.get(pred.name)
if not real_pred:
# Remember the value of pred for every iteration.
@@ -118,8 +120,8 @@ def _MergeGrad(op, grad, _):
return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
# pylint: enable=protected-access
else:
- num_inputs = len(real_op.inputs)
- cond = [math_ops.equal(real_op.outputs[1], i) for i in xrange(num_inputs)]
+ num_inputs = len(op.inputs)
+ cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
# pylint: disable=protected-access
return [control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
for i in xrange(num_inputs)]
@@ -132,16 +134,17 @@ def _RefMergeGrad(op, grad, _):
@ops.RegisterGradient("Exit")
-def _ExitGrad(op, grad):
+def _ExitGrad(_, grad):
"""Gradients for an exit op are calculated using an Enter op."""
- real_op = GetRealOp(op)
+ graph = ops.get_default_graph()
# pylint: disable=protected-access
- forward_ctxt = real_op._get_control_flow_context()
+ grad_ctxt = graph._get_control_flow_context()
# pylint: enable=protected-access
- if not forward_ctxt.back_prop:
- # No gradient computation for this loop.
+ if not grad_ctxt.back_prop:
+ # The flag `back_prop` is set by users to suppress gradient
+ # computation for this loop. If the flag `back_prop` is false,
+ # no gradient computation.
return None
- grad_ctxt = op.grad_state.grad_context
grad_ctxt.AddName(grad.name)
enter_fn = control_flow_ops._Enter # pylint: disable=protected-access
grad_ctxt.Enter()
@@ -176,17 +179,14 @@ def _EnterGrad(op, grad):
For loop variables, grad is the gradient so just add an exit.
For loop invariants, we need to add an accumulator loop.
"""
- real_op = GetRealOp(op)
+ graph = ops.get_default_graph()
# pylint: disable=protected-access
- forward_ctxt = real_op._get_control_flow_context()
+ grad_ctxt = graph._get_control_flow_context()
# pylint: enable=protected-access
- if not forward_ctxt.back_prop:
- # The flag `back_prop` is set by users to suppress gradient
- # computation for this loop. If the flag `back_prop` is true,
- # no gradient computation.
+ if not grad_ctxt.back_prop:
+ # If the flag `back_prop` is false, no gradient computation.
return grad
- grad_ctxt = op.grad_state.grad_context
- if real_op.get_attr("is_constant"):
+ if op.get_attr("is_constant"):
# Add a gradient accumulator for each loop invariant.
result = grad_ctxt.AddBackPropAccumulator(grad)
else:
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index cfd2bed9c5..e7f8a6d76c 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -279,23 +279,23 @@ def _SwitchRefOrTensor(data, pred, name="Switch"):
TypeError: if data is not a Tensor or IndexedSlices
"""
data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
- # NOTE(mrry): ops.device(None) below addresses the following scenario.
+ # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
+ # addresses the following scenario.
#
# Assume you execute Optimizer.apply_gradients() in a branch of a cond().
#
- # 1. The update op is created inside a `with tf.device(var.device):` block
- # say var.device = "/job:ps/task:1".
+ # 1. The update op is created inside a `with ops.colocate_with(var):` block
#
# 2. Some tensor `data` is captured and a switch is created in a
- # `with tf.device(data.device):` block (data.device = "/job:worker_train").
+ # `with ops.colocate_with(data):` block.
#
- # with tf.device("/job:ps/task:1"):
- # with tf.device("/job:worker_train"):
+ # with ops.colocate_with(var):
+ # with ops.colocate_with(data):
# op = ...
#
- # But then calling `print op.device` returns:
- # ==> "/job:worker_train/task:1" -- a device that doesn't exist in this case!
- with ops.colocate_with(data):
+ # var and data may be pinned to different devices, so we want the ops
+ # created within ops.colocate_with(data) to ignore the existing stack.
+ with ops.colocate_with(data, ignore_existing=True):
if isinstance(data, ops.Tensor):
if not data.dtype.is_ref_dtype:
return switch(data, pred, name=name)
@@ -324,142 +324,21 @@ def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)]
-class ControlFlowOpWrapper(object):
- """A wrapper class for Operation.
-
- A wrapped op allows us to capture the uses of its inputs and outputs. In
- gradients(), right before calling the gradient function of an op, we wrap
- the op by calling MakeWrapper. So during the exection of the gradient
- function of an op , any time when one of its inputs/outputs is used, we
- generate code to remember its values for all iterations.
- """
-
- class _ControlFlowOpInputs(object):
- """An indirection to capture the input tensors needed in backprop."""
-
- def __init__(self, op, grad_state):
- self._op = op
- self._grad_state = grad_state
- self._inputs = None
-
- def __len__(self):
- return len(self._op._inputs)
-
- def __getitem__(self, index):
- if self._inputs is None:
- self._inputs = [None for _ in self._op.inputs]
- if isinstance(index, int):
- val = self._inputs[index]
- if val is None:
- f_val = self._op.inputs[index]
- val = self._grad_state.GetRealValue(f_val)
- self._inputs[index] = val
- return val
- elif isinstance(index, slice):
- start, stop, step = index.indices(len(self))
- vals = [self[i] for i in xrange(start, stop, step)]
- return vals
- else:
- raise TypeError("index must be an integer or slice")
-
- class _ControlFlowOpOutputs(object):
- """An indirection to capture the output tensors needed in backprop."""
-
- def __init__(self, op, grad_state):
- self._op = op
- self._grad_state = grad_state
- self._outputs = None
-
- def __len__(self):
- return len(self._op._outputs)
-
- def __getitem__(self, index):
- if self._outputs is None:
- self._outputs = [None for _ in self._op.outputs]
- if isinstance(index, int):
- val = self._outputs[index]
- if val is None:
- f_val = self._op.outputs[index]
- val = self._grad_state.GetRealValue(f_val)
- self._outputs[index] = val
- return val
- elif isinstance(index, slice):
- start, stop, step = index.indices(len(self))
- vals = [self[i] for i in xrange(start, stop, step)]
- return vals
- else:
- raise TypeError("index must be an integer or slice")
-
- def __init__(self, op, grad_state):
- self._grad_state = grad_state # The GradLoopState this op belongs to.
- self._op = op
- self._inputs = None
- self._outputs = None
-
- @property
- def grad_state(self):
- return self._grad_state
-
- @property
- def inputs(self):
- if self._inputs is None:
- self._inputs = self._ControlFlowOpInputs(self._op, self._grad_state)
- return self._inputs
-
- @property
- def outputs(self):
- if self._outputs is None:
- self._outputs = self._ControlFlowOpOutputs(self._op, self._grad_state)
- return self._outputs
-
- @property
- def op(self):
- return self._op
-
- @property
- def name(self):
- """Returns the name of this instance of op."""
- return self._op.name
-
- @property
- def _id(self):
- """Returns the unique id of this operation."""
- return self._op._id
-
- @property
- def device(self):
- """Returns the device of this operation.
-
- Returns:
- a string or None if the device was not set.
- """
- return self._op.device
-
- @property
- def type(self):
- """Returns the type of the op."""
- return self._op.type
-
- @property
- def graph(self):
- """The `Graph` that contains this operation."""
- return self._op.graph
-
- def get_attr(self, name):
- """Returns the value of the attr of this op with the given `name`."""
- return self._op.get_attr(name)
-
- def _get_control_flow_context(self):
- """Returns the control flow context of this op."""
- return self._op._get_control_flow_context()
-
-
def _IsLoopConstantEnter(op):
- """Returns true iff op is a loop invariant."""
+ """Return true iff op is a loop invariant."""
is_enter = (op.type == "Enter" or op.type == "RefEnter")
return is_enter and op.get_attr("is_constant")
+def _GetLoopConstantEnter(value):
+ """Return the enter op if we can infer `value` to be a loop invariant."""
+ id_ops = {"Switch", "RefSwitch", "Identity", "RefIdentity"}
+ op = value.op
+ while op.type in id_ops:
+ op = op.inputs[0].op
+ return op if _IsLoopConstantEnter(op) else None
+
+
def _IsLoopExit(op):
return op.type == "Exit" or op.type == "RefExit"
@@ -531,7 +410,8 @@ class GradLoopState(object):
self._grad_context = WhileContext(forward_ctxt.parallel_iterations,
forward_ctxt.back_prop,
forward_ctxt.swap_memory,
- forward_ctxt.name)
+ forward_ctxt.name,
+ self)
real_cnt = outer_grad_state.AddBackPropAccumulatedValue(history_cnt, cnt)
self._grad_index = self._grad_context.AddBackPropCounter(real_cnt)
outer_grad_ctxt.Exit()
@@ -540,7 +420,8 @@ class GradLoopState(object):
self._grad_context = WhileContext(forward_ctxt.parallel_iterations,
forward_ctxt.back_prop,
forward_ctxt.swap_memory,
- forward_ctxt.name)
+ forward_ctxt.name,
+ self)
self._grad_index = self._grad_context.AddBackPropCounter(cnt)
if outer_forward_ctxt: outer_forward_ctxt.Exit()
@@ -629,55 +510,59 @@ class GradLoopState(object):
edge from the push op to either `forward_index.op` or `forward_sync`.
Args:
- value: The tensor that is to be accumulated.
+ value: The source tensor in forward that is to be accumulated.
dead_branch: True iff the tensor is on a dead branch of a cond.
Returns:
The stack that contains the accumulated history of the tensor.
"""
- # TODO(yuanbyu): Make sure the colocation of stack ops and value.
- # pylint: disable=protected-access
- acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc")
- # pylint: enable=protected-access
-
- # Make acc available in the forward context.
- enter_acc = self.forward_context.AddValue(acc)
-
- # Add the stack_push op in the context of value.op.
- swap_enabled = self.forward_context.swap_memory
- value_ctxt = value.op._get_control_flow_context()
- if _IsLoopExit(value.op):
- value_ctxt = value_ctxt.outer_context
- if value_ctxt == self.forward_context:
- # value is not nested in the forward context.
- self.forward_context.Enter()
- push = gen_data_flow_ops._stack_push(enter_acc, value,
- swap_memory=swap_enabled)
- self.forward_context.Exit()
- # Protect stack push and order it before forward_index.
- self.forward_index.op._add_control_input(push.op)
- else:
- # value is in a cond context within the forward context.
- assert isinstance(value_ctxt, CondContext)
- if dead_branch:
- # The special case for creating a zero tensor for a dead
- # branch of a switch. See ControlFlowState.ZerosLike().
- value_ctxt.outer_context.Enter()
- push = gen_data_flow_ops._stack_push(enter_acc, value,
- swap_memory=swap_enabled)
- value_ctxt.outer_context.Exit()
- push.op._set_control_flow_context(value_ctxt)
+ curr_ctxt = ops.get_default_graph()._get_control_flow_context()
+ with ops.control_dependencies(None):
+ if curr_ctxt: curr_ctxt.Enter()
+ with ops.colocate_with(value):
+ # pylint: disable=protected-access
+ acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc")
+ # pylint: enable=protected-access
+ if curr_ctxt: curr_ctxt.Exit()
+
+ # Make acc available in the forward context.
+ enter_acc = self.forward_context.AddValue(acc)
+
+ # Add the stack_push op in the context of value.op.
+ swap_enabled = self.forward_context.swap_memory
+ value_ctxt = value.op._get_control_flow_context()
+ if _IsLoopExit(value.op):
+ value_ctxt = value_ctxt.outer_context
+ if value_ctxt == self.forward_context:
+ # value is not nested in the forward context.
+ self.forward_context.Enter()
+ push = gen_data_flow_ops._stack_push(
+ enter_acc, value, swap_memory=swap_enabled)
+ self.forward_context.Exit()
+ # Protect stack push and order it before forward_index.
+ self.forward_index.op._add_control_input(push.op)
else:
- value_ctxt.Enter()
- push = gen_data_flow_ops._stack_push(enter_acc, value,
- swap_memory=swap_enabled)
- value_ctxt.Exit()
- # Protect stack push and order it before forward_sync.
- self.forward_sync._add_control_input(push.op)
- # Order stack push after the successor of forward_index
- add_op = self.forward_index.op.inputs[0].op
- push.op._add_control_input(add_op)
- return acc
+ # value is in a cond context within the forward context.
+ assert isinstance(value_ctxt, CondContext)
+ if dead_branch:
+ # The special case for creating a zero tensor for a dead
+ # branch of a switch. See ControlFlowState.ZerosLike().
+ value_ctxt.outer_context.Enter()
+ push = gen_data_flow_ops._stack_push(
+ enter_acc, value, swap_memory=swap_enabled)
+ value_ctxt.outer_context.Exit()
+ push.op._set_control_flow_context(value_ctxt)
+ else:
+ value_ctxt.Enter()
+ push = gen_data_flow_ops._stack_push(
+ enter_acc, value, swap_memory=swap_enabled)
+ value_ctxt.Exit()
+ # Protect stack push and order it before forward_sync.
+ self.forward_sync._add_control_input(push.op)
+ # Order stack push after the successor of forward_index
+ add_op = self.forward_index.op.inputs[0].op
+ push.op._add_control_input(add_op)
+ return acc
def AddBackPropAccumulatedValue(self, history_value, value,
dead_branch=False):
@@ -704,60 +589,67 @@ class GradLoopState(object):
cond_ctxt = value_ctxt
break
value_ctxt = value_ctxt.outer_context
- if cond_ctxt:
- # Guard stack pop with a switch if it is controlled by a cond
- grad_state = self
- pred = None
- while not pred and grad_state:
- pred = grad_state.history_map.get(cond_ctxt.pred.name)
- grad_state = grad_state.outer_grad_state
- branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
- history_value = _SwitchRefOrTensor(history_value, pred)[branch]
- pop = gen_data_flow_ops._stack_pop(history_value, value.dtype.base_dtype)
+ with ops.control_dependencies(None):
+ self.grad_context.Enter()
+ if cond_ctxt:
+ # Guard stack pop with a switch if it is controlled by a cond
+ grad_state = self
+ pred = None
+ while not pred and grad_state:
+ pred = grad_state.history_map.get(cond_ctxt.pred.name)
+ grad_state = grad_state.outer_grad_state
+ branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
+ history_value = _SwitchRefOrTensor(history_value, pred)[branch]
+ pop = gen_data_flow_ops._stack_pop(history_value, value.dtype.base_dtype)
+ self.grad_context.Exit()
if self.grad_context.parallel_iterations > 1:
# All pops are ordered after pivot_for_body and before grad_sync.
self.grad_sync._add_control_input(pop.op)
return pop
def GetRealValue(self, value):
- """Get the real value.
+ """Get the real value of `value`.
- If backprop "uses" a value produced by forward inference, an
- accumulator is added in the forward loop to accumulate its values.
- We use the accumulated value.
+ If backprop "uses" a value produced by forward inference, an accumulator
+ is added in the forward loop to accumulate its values. We use the
+ accumulated value. This method must be called in the grad loop context.
+ `value` must be in forward and needed for backprop.
Args:
value: A tensor to be captured.
Returns:
- The same tensor value from the saved history.
+ The same tensor obtained from the saved history.
"""
assert value.op.type != "Variable"
real_value = self._history_map.get(value.name)
if real_value is None:
- if _IsLoopConstantEnter(value.op):
- # Special case for loop invariant.
- if self._outer_grad_state:
- # This is a nested loop so we record the history of this
- # value in outer_forward_ctxt.
+ cur_value = value
+ cur_grad_state = self
+ while True:
+ enter_op = _GetLoopConstantEnter(cur_value)
+ if enter_op:
+ # Special case: cur_value comes from a constant Enter node.
+ cur_value = enter_op.inputs[0]
+ if self._outer_grad_state:
+ cur_grad_state = cur_grad_state.outer_grad_state
+ else:
+ # We are now outside all nested loops for this gradient(),
+ # so `value` is a loop invariant and there is no need to
+ # save the history of value.
+ real_value = self._grad_context.AddValue(cur_value)
+ break
+ else:
+ # Record the history of this value in forward_ctxt.
+ # TODO(yuanbyu): Avoid recording constants.
self._grad_context.Exit()
- outer_value = value.op.inputs[0]
- history_value = self._outer_grad_state.AddForwardAccumulator(
- outer_value)
+ h_value = cur_grad_state.AddForwardAccumulator(cur_value)
self._grad_context.Enter()
- else:
- # Just use the input value of this Enter node.
- real_value = GetRealOp(value.op).inputs[0]
- else:
- # Record the history of this value in forward_ctxt.
- # NOTE(yuanbyu): Don't record for constants.
- self._grad_context.Exit()
- history_value = self.AddForwardAccumulator(value)
- self._grad_context.Enter()
+ break
if real_value is None:
# Add the stack pop op in the grad context.
- real_value = self.AddBackPropAccumulatedValue(history_value, value)
+ real_value = self.AddBackPropAccumulatedValue(h_value, value)
self._history_map[value.name] = real_value
return real_value
@@ -776,9 +668,9 @@ class ControlFlowState(object):
def __init__(self):
self._map = {} # maps forward loop context to GradLoopState
- def _GetGradState(self, op):
- """Get the gradient loop state for this op if any."""
- if _IsLoopExit(op):
+ def _GetGradState(self, op, before):
+ """Return the grad state for this op if it's in a forward loop context."""
+ if before and _IsLoopExit(op):
forward_ctxt = op._get_control_flow_context()
forward_ctxt = forward_ctxt.outer_context
if forward_ctxt:
@@ -789,15 +681,6 @@ class ControlFlowState(object):
return self._map.get(forward_ctxt)
return None
- def MakeWrapper(self, op):
- """Make a wrapper for op if it is in a WhileContext."""
- forward_ctxt = _GetWhileContext(op)
- if forward_ctxt:
- grad_state = self._map.get(forward_ctxt)
- if grad_state:
- return ControlFlowOpWrapper(op, grad_state)
- return op
-
def GetAllLoopExits(self):
"""Return a list containing the exits of all the loops."""
loop_exits = []
@@ -806,15 +689,15 @@ class ControlFlowState(object):
loop_exits.append(loop_exit)
return loop_exits
- def EnterGradWhileContext(self, op):
+ def EnterGradWhileContext(self, op, before):
"""Enter the WhileContext for gradient computation."""
- grad_state = self._GetGradState(op)
+ grad_state = self._GetGradState(op, before)
if grad_state:
grad_state.grad_context.Enter()
- def ExitGradWhileContext(self, op):
+ def ExitGradWhileContext(self, op, before):
"""Exit the WhileContext for gradient computation."""
- grad_state = self._GetGradState(op)
+ grad_state = self._GetGradState(op, before)
if grad_state:
grad_state.grad_context.Exit()
@@ -877,12 +760,18 @@ class ControlFlowState(object):
result = array_ops.zeros(val_shape.dims, val.dtype)
outer_grad_state.grad_context.Exit()
else:
- history_val = outer_grad_state.AddForwardAccumulator(val)
+ # Only the shape of value is needed for backprop.
+ forward_ctxt.outer_context.Enter()
+ shape = array_ops.shape(value)
+ forward_ctxt.outer_context.Exit()
+ # Save the shape to a stack.
+ history_shape = outer_grad_state.AddForwardAccumulator(shape)
+ # Get the shape back from the stack.
outer_grad_ctxt = outer_grad_state.grad_context
outer_grad_ctxt.Enter()
- real_val = outer_grad_state.AddBackPropAccumulatedValue(
- history_val, val)
- result = array_ops.zeros_like(real_val)
+ real_shape = outer_grad_state.AddBackPropAccumulatedValue(
+ history_shape, shape)
+ result = array_ops.zeros(real_shape, value.dtype)
outer_grad_ctxt.Exit()
else:
# This is not a nested loop.
@@ -943,23 +832,17 @@ class ControlFlowState(object):
# Add forward accumulator for shape.
grad_state.grad_context.Exit()
- history_shape = grad_state.AddForwardAccumulator(zeros_shape, dead_branch)
+ h_shape = grad_state.AddForwardAccumulator(
+ zeros_shape, dead_branch=dead_branch)
grad_state.grad_context.Enter()
# Create a zero tensor with the right shape.
shape = grad_state.AddBackPropAccumulatedValue(
- history_shape, zeros_shape, dead_branch)
+ h_shape, zeros_shape, dead_branch)
result = array_ops.zeros(shape, val.dtype)
return result
-def GetRealOp(op):
- """Get the real op by removing the wrapper."""
- while isinstance(op, ControlFlowOpWrapper):
- op = op.op
- return op
-
-
def MaybeCreateControlFlowState(between_op_list, between_ops):
"""Create the state for all the while loops involved in one gradients().
@@ -1106,6 +989,9 @@ class CondContext(ControlFlowContext):
return result
def AddOp(self, op):
+ self._AddOpInternal(op)
+
+ def _AddOpInternal(self, op):
"""Add `op` to the current context."""
if not op.inputs:
# Add this op to the enclosing while context
@@ -1248,7 +1134,8 @@ def cond(pred, fn1, fn2, name=None):
class WhileContext(ControlFlowContext):
"""The context for the loop construct."""
- def __init__(self, parallel_iterations, back_prop, swap_memory, name):
+ def __init__(self, parallel_iterations, back_prop, swap_memory, name,
+ grad_state=None):
ControlFlowContext.__init__(self)
self._name = ops.get_default_graph().unique_name(name)
self._parallel_iterations = parallel_iterations
@@ -1263,6 +1150,8 @@ class WhileContext(ControlFlowContext):
self._pivot = None
# The list of exit tensors for loop variables.
self._loop_exits = None
+ # The gradient loop state
+ self._grad_state = grad_state
@property
def name(self):
@@ -1293,6 +1182,11 @@ class WhileContext(ControlFlowContext):
"""The list of exit tensors for loop variables."""
return self._loop_exits
+ @property
+ def grad_state(self):
+ """The gradient loop state."""
+ return self._grad_state
+
def GetWhileContext(self):
return self
@@ -1306,6 +1200,22 @@ class WhileContext(ControlFlowContext):
result = val
if val.name not in self._values:
self._values.add(val.name)
+
+ # If we are in a grad context and val is from its forward context,
+ # use GetRealValue(), which adds the logic to save the history of
+ # val in the forward pass.
+ grad_ctxt = ops.get_default_graph()._get_control_flow_context()
+ if grad_ctxt:
+ grad_ctxt = grad_ctxt.GetWhileContext()
+ if grad_ctxt.grad_state:
+ forward_ctxt = _GetWhileContext(val.op)
+ if _IsLoopExit(val.op):
+ forward_ctxt = forward_ctxt.outer_context
+ if forward_ctxt == grad_ctxt.grad_state.forward_context:
+ real_val = grad_ctxt.grad_state.GetRealValue(val)
+ self._external_values[val.name] = real_val
+ return real_val
+
if self._outer_context is not None:
result = self._outer_context.AddValue(val)
# Create an Enter to make `result` known to this loop context.
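A plain-Python schematic of the substitution added above (the class and helpers here are stand-ins, not the real TensorFlow objects): a tensor produced in the forward loop is replaced, inside the gradient context, by the copy replayed from the forward accumulator, and the replacement is memoized so later uses see the same value.

class FakeGradState(object):
    def __init__(self, history):
        self.history = history                  # name -> value saved in forward

    def get_real_value(self, name):
        return self.history[name]               # replay from the history stack

def add_value(name, value, grad_state, external_values):
    if grad_state is not None and name in grad_state.history:
        real_val = grad_state.get_real_value(name)
        external_values[name] = real_val        # memoize the substitution
        return real_val
    return value

state = FakeGradState({"while/Add:0": "value replayed from the stack"})
cache = {}
print(add_value("while/Add:0", "stale forward handle", state, cache))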
@@ -1327,7 +1237,27 @@ class WhileContext(ControlFlowContext):
return result
def AddOp(self, op):
- """Adds `op` to the current context."""
+ """Add `op` to the current context."""
+ # For a reduction op, if op is in a grad context and its input is from
+ # its forward context, moving op to the forward context means we would
+ # store the tensor after the reduction as opposed to the tensor before
+ # reduction, and therefore could significantly reduce memory consumption.
+ # For now, we do this only for a few ops.
+ if op.type in {"Shape", "Size", "Rank"}:
+ grad_ctxt = ops.get_default_graph()._get_control_flow_context()
+ if grad_ctxt:
+ grad_ctxt = grad_ctxt.GetWhileContext()
+ if grad_ctxt.grad_state:
+ op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op)
+ if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
+ op_input_ctxt = op.inputs[0].op._get_control_flow_context()
+ op._set_control_flow_context(op_input_ctxt)
+ op_input_ctxt._AddOpInternal(op)
+ return
+ self._AddOpInternal(op)
+
+ def _AddOpInternal(self, op):
+ """Add `op` to the current context."""
if not op.inputs:
if not op.control_inputs:
# Add a control edge from the control pivot to this op.
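The Shape/Size/Rank special case above can be read as a placement rule: ops whose output is tiny relative to their input are re-homed into the context that produced the input, so the forward accumulator only has to save the small result. A schematic of the rule (a sketch with a hypothetical helper, not the real placement code):

SMALL_OUTPUT_OPS = {"Shape", "Size", "Rank"}

def choose_context(op_type, producer_context, gradient_context):
    # Hypothetical helper: decide where an op created during backprop should live.
    if op_type in SMALL_OUTPUT_OPS:
        return producer_context    # compute in forward; stack only the small output
    return gradient_context        # everything else stays in the gradient context

print(choose_context("Shape", "forward_while_ctxt", "grad_while_ctxt"))
print(choose_context("MatMul", "forward_while_ctxt", "grad_while_ctxt"))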
@@ -1863,7 +1793,6 @@ def foldr(fn, elems, initializer=None, name=None):
fn: The function to be performed.
elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
initializer: (optional) The initial value for the accumulator.
- use_tensor_array: (optional) use tensor_array if true.
name: (optional) Name prefix for the returned tensors.
Returns:
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index ced841f269..9fc1aa80d1 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -427,14 +427,14 @@ def gradients(ys,
op = queue.popleft()
with _maybe_colocate_with(op, colocate_gradients_with_ops):
if loop_state:
- loop_state.EnterGradWhileContext(op)
+ loop_state.EnterGradWhileContext(op, before=True)
out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
- grad_fn = None
+ if loop_state:
+ loop_state.ExitGradWhileContext(op, before=True)
+ grad_fn = None
# pylint: disable=protected-access
is_func_call = ops.get_default_graph()._is_function(op.type)
- # pylint: enable=protected-access
-
if not is_func_call and any(out_grads) and op._id not in stop_ops:
# pylint: enable=protected-access
# A grad_fn must be defined, either as a function or as None
@@ -445,6 +445,9 @@ def gradients(ys,
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
+
+ if loop_state:
+ loop_state.EnterGradWhileContext(op, before=False)
if (grad_fn or is_func_call) and any(out_grads):
# NOTE: If _AggregatedGrads didn't compute a value for the i'th
# output, it means that the cost does not depend on output[i],
@@ -461,9 +464,6 @@ def gradients(ys,
# pylint: disable=protected-access
with ops.get_default_graph()._original_op(op):
# pylint: enable=protected-access
- wrapped_op = op
- if loop_state:
- wrapped_op = loop_state.MakeWrapper(op)
if is_func_call:
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
@@ -474,7 +474,7 @@ def gradients(ys,
f_in, f_types, op.type))
# pylint: enable=protected-access
else:
- in_grads = _AsList(grad_fn(wrapped_op, *out_grads))
+ in_grads = _AsList(grad_fn(op, *out_grads))
_VerifyGeneratedGradients(in_grads, op)
if gate_gradients and len(tuple(filter(None, in_grads))) > 1:
in_grads = control_flow_ops.tuple(in_grads)
@@ -491,7 +491,7 @@ def gradients(ys,
if in_grad:
_SetGrad(grads, t_in, in_grad)
if loop_state:
- loop_state.ExitGradWhileContext(op)
+ loop_state.ExitGradWhileContext(op, before=False)
# update pending count for the inputs of op.
# pylint: disable=protected-access
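Taken together, the gradients.py changes split the while-loop bookkeeping into two bracketed phases per op: one around gradient aggregation (before=True) and one around building the op's gradient subgraph (before=False), with the op passed to grad_fn directly instead of through a wrapper. A runnable toy trace of that bracketing (the loop state below is a stand-in, not the real ControlFlowState):

class FakeLoopState(object):
    # Records enter/exit calls so the two-phase pattern is visible.
    def __init__(self):
        self.trace = []
    def EnterGradWhileContext(self, op, before):
        self.trace.append(("enter", op, before))
    def ExitGradWhileContext(self, op, before):
        self.trace.append(("exit", op, before))

loop_state = FakeLoopState()
for op in ["MatMul", "Add"]:                 # reverse topological order, schematically
    loop_state.EnterGradWhileContext(op, before=True)
    out_grads = ["aggregated grads for " + op]
    loop_state.ExitGradWhileContext(op, before=True)

    loop_state.EnterGradWhileContext(op, before=False)
    in_grads = ["grad_fn(%s) output" % op]   # grad_fn now receives the op directly
    loop_state.ExitGradWhileContext(op, before=False)

print(loop_state.trace)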
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 619922983a..d65910abb2 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -644,7 +644,8 @@ def resize_images(images,
new_width_const = tensor_util.constant_value(new_width)
new_height_const = tensor_util.constant_value(new_height)
- if width == new_width_const and height == new_height_const:
+ if new_width_const is not None and new_height_const is not None and (
+ width == new_width_const and height == new_height_const):
if not is_batch:
images = array_ops.squeeze(images, squeeze_dims=[0])
return images
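The guard added above ensures the early return is taken only when both target sizes are statically known: tensor_util.constant_value() returns None for values that are only known at run time, and a comparison involving None must not be allowed to select the skip-resize branch. A small sketch of the corrected predicate (plain Python, with ints standing in for the shape values):

def should_return_input(width, height, new_width_const, new_height_const):
    # Skip the resize only when the target sizes are statically known and equal
    # to the current sizes.
    return (new_width_const is not None and new_height_const is not None and
            width == new_width_const and height == new_height_const)

print(should_return_input(None, None, None, None))  # False: sizes unknown, resize
print(should_return_input(32, 32, 32, 32))          # True: resize can be skipped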
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index e9a029259c..611f5fa314 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -269,6 +269,10 @@ def _reverse_seq(input_seq, lengths):
# Join into (time, batch_size, depth)
s_joined = array_ops.pack(input_seq)
+ # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32
+ if lengths is not None:
+ lengths = math_ops.to_int64(lengths)
+
# Reverse along dimension 0
s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
# Split again into list
@@ -346,9 +350,9 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
return (outputs, output_state_fw, output_state_bw)
-def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
- parallel_iterations=None, swap_memory=False, time_major=False,
- scope=None):
+def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
+ dtype=None, parallel_iterations=None, swap_memory=False,
+ time_major=False, scope=None):
"""Creates a recurrent neural network specified by RNNCell "cell".
This function is functionally identical to the function `rnn` above, but
@@ -369,9 +373,9 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
`[batch_size, max_time, cell.input_size]`.
If time_major == True, this must be a tensor of shape:
`[max_time, batch_size, cell.input_size]`.
- sequence_length: An int32/int64 vector (tensor) size [batch_size].
+ sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
initial_state: (optional) An initial state for the RNN. This must be
- a tensor of appropriate type and shape [batch_size x cell.state_size].
+ a tensor of appropriate type and shape `[batch_size x cell.state_size]`.
dtype: (optional) The data type for the initial state. Required if
initial_state is not provided.
parallel_iterations: (Default: 32). The number of iterations to run in
@@ -415,8 +419,10 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,D) => (T,B,D)
parallel_iterations = parallel_iterations or 32
- sequence_length = math_ops.to_int32(sequence_length)
- sequence_length = array_ops.identity(sequence_length, name="sequence_length")
+ if sequence_length is not None:
+ sequence_length = math_ops.to_int32(sequence_length)
+ sequence_length = array_ops.identity( # Just to find it in the graph.
+ sequence_length, name="sequence_length")
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
@@ -442,15 +448,16 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
["Expected shape for Tensor %s is " % x.name,
packed_shape, " but saw shape: ", x_shape])
- # Perform some shape validation
- with ops.control_dependencies(
- [_assert_has_shape(sequence_length, [batch_size])]):
- sequence_length = array_ops.identity(sequence_length, name="CheckSeqLen")
+ if sequence_length is not None:
+ # Perform some shape validation
+ with ops.control_dependencies(
+ [_assert_has_shape(sequence_length, [batch_size])]):
+ sequence_length = array_ops.identity(
+ sequence_length, name="CheckSeqLen")
(outputs, final_state) = _dynamic_rnn_loop(
- cell, inputs, state, sequence_length,
- parallel_iterations=parallel_iterations,
- swap_memory=swap_memory)
+ cell, inputs, state, parallel_iterations=parallel_iterations,
+ swap_memory=swap_memory, sequence_length=sequence_length)
# Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
# If we are performing batch-major calculations, transpose output back
@@ -461,17 +468,18 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
return (outputs, final_state)
-def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
- parallel_iterations, swap_memory):
+def _dynamic_rnn_loop(
+ cell, inputs, initial_state, parallel_iterations, swap_memory,
+ sequence_length=None):
"""Internal implementation of Dynamic RNN.
Args:
cell: An instance of RNNCell.
inputs: A `Tensor` of shape [time, batch_size, depth].
initial_state: A `Tensor` of shape [batch_size, depth].
- sequence_length: An `int32` `Tensor` of shape [batch_size].
parallel_iterations: Positive Python int.
swap_memory: A Python boolean
+ sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
Returns:
Tuple (final_outputs, final_state).
@@ -502,8 +510,9 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
# Prepare dynamic conditional copying of state & output
zero_output = array_ops.zeros(
array_ops.pack([batch_size, cell.output_size]), inputs.dtype)
- min_sequence_length = math_ops.reduce_min(sequence_length)
- max_sequence_length = math_ops.reduce_max(sequence_length)
+ if sequence_length is not None:
+ min_sequence_length = math_ops.reduce_min(sequence_length)
+ max_sequence_length = math_ops.reduce_max(sequence_length)
time = array_ops.constant(0, dtype=dtypes.int32, name="time")
@@ -536,9 +545,14 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
# Restore some shape information
input_t.set_shape([const_batch_size, const_depth])
- (output, new_state) = _rnn_step(
- time, sequence_length, min_sequence_length, max_sequence_length,
- zero_output, state, lambda: cell(input_t, state))
+ call_cell = lambda: cell(input_t, state)
+
+ if sequence_length is not None:
+ (output, new_state) = _rnn_step(
+ time, sequence_length, min_sequence_length, max_sequence_length,
+ zero_output, state, call_cell)
+ else:
+ (output, new_state) = call_cell()
output_ta_t = output_ta_t.write(time, output)
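With these changes sequence_length is optional end to end. A usage sketch against the module as modified here (cell, inputs and lengths are assumed to be defined elsewhere; dtype is only required because no initial_state is passed):

import tensorflow as tf
from tensorflow.python.ops import rnn

# No sequence_length: every example runs for the full max_time steps.
outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=tf.float32)

# With sequence_length: steps past an example's length emit zero output and
# carry its state through unchanged, exactly as before this change.
outputs, state = rnn.dynamic_rnn(cell, inputs, sequence_length=lengths,
                                 dtype=tf.float32)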
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 65ea4d2e17..766bdf7dd3 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -28,5 +28,6 @@ limitations under the License.
%include "tensorflow/python/client/events_writer.i"
%include "tensorflow/python/client/tf_session.i"
+%include "tensorflow/python/client/server_lib.i"
%include "tensorflow/python/framework/python_op_gen.i"
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index 6bc36429d9..661bae7bc1 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -335,7 +335,6 @@ class LooperThread(threading.Thread):
looper.start()
return looper
- # pylint: disable=broad-except
def run(self):
with self._coord.stop_on_exception():
self.start_loop()
@@ -349,12 +348,16 @@ class LooperThread(threading.Thread):
while not self._coord.wait_for_stop(next_timer_time - time.time()):
next_timer_time += self._timer_interval_secs
self.run_loop()
- # pylint: enable=broad-except
+ self.stop_loop()
def start_loop(self):
"""Called when the thread starts."""
pass
+ def stop_loop(self):
+ """Called when the thread stops."""
+ pass
+
def run_loop(self):
"""Called at 'timer_interval_secs' boundaries."""
if self._target:
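The new stop_loop() hook runs once when the looper leaves its loop, mirroring start_loop(). A sketch of how a subclass might use it (the counter here is illustrative, not part of the API):

from tensorflow.python.training.coordinator import LooperThread

class CountingLooper(LooperThread):
    def start_loop(self):
        self.count = 0                    # per-thread setup when the thread starts

    def run_loop(self):
        self.count += 1                   # periodic work at each timer interval

    def stop_loop(self):
        # New hook: runs once after the coordinator requests a stop.
        print("ran %d iterations" % self.count)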
diff --git a/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts b/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts
index 8f858becb2..ede6a1f5a3 100644
--- a/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts
+++ b/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../categorizer.ts" />
var assert = chai.assert;
module Categorizer {
diff --git a/tensorflow/tensorboard/components/tf-categorizer/test/index.html b/tensorflow/tensorboard/components/tf-categorizer/test/index.html
new file mode 100644
index 0000000000..fd4a097708
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf-categorizer/test/index.html
@@ -0,0 +1,13 @@
+<!doctype html>
+<html>
+<head>
+ <meta charset="utf-8">
+ <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script>
+ <script src="../../web-component-tester/browser.js"></script>
+ <link rel="import" href="../../tf-imports/d3.html">
+</head>
+<body>
+ <script src="../categorizer.js"></script>
+ <script src="categorizerTest.js"></script>
+</body>
+</html>
diff --git a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts
index 7148fd3fce..00c593a049 100644
--- a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts
+++ b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts
@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../plottable/plottable.d.ts" />
-
module TF {
export module Urls {
export type RunTagUrlFn = (tag: string, run: string) => string;
@@ -69,7 +66,21 @@ module TF {
};
};
- export function demoRouter(dataDir: string): Router {
+ export function demoRouter(dataDir: string,
+ oldVersion = false): Router {
+ if (oldVersion) {
+ return {
+ runs: () => dataDir + "runs.json",
+ graph: (run) => dataDir + run + "-graph.pbtxt",
+ scalars: (tag, run) => {
+ return dataDir + run.split("_")[0] + ".json";
+ },
+ histograms: () => null,
+ compressedHistograms: () => null,
+ images: () => null,
+ individualImage: () => null
+ };
+ }
/* Retrieves static .json data generated by demo_from_server.py */
function demoRoute(route) {
return function(tag, run) {
diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts b/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts
index 489a2138f0..5407800710 100644
--- a/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts
+++ b/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts
@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../plottable/plottable.d.ts" />
-
module TF {
/* The DataCoordinator generates TF.Datasets for each run/tag combination,
diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts b/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts
index 8ced6ad0e2..3677a300d1 100644
--- a/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts
+++ b/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts
@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../plottable/plottable.d.ts" />
-
module TF {
/* An extension of Plottable.Dataset that knows how to load data from a backend.
*/
diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts
index 05fcf6b3e9..d799c190cf 100644
--- a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts
+++ b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts
@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../plottable/plottable.d.ts" />
-
module TF {
type TFDatum = [number, number, number];
type tooltipMap = {[run: string]: string};
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts
index ed89706b45..b2f6d21598 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="common.ts" />
module tf.graph {
/** Delimiter used in node names to denote namespaces. */
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
index 98f34bdd3f..af5c1e97b6 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="graph.ts" />
-/// <reference path="template.ts" />
-
/**
* Package for the Graph Hierarchy for TensorFlow graph.
*/
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
index 0e7b1d17d5..0d9e5b53bf 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="graph.ts" />
-/// <reference path="render.ts" />
-
module tf.graph.layout {
/** Set of parameters that define the look and feel of the graph. */
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts
index f88da0dd33..6d1aa875ee 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="common.ts" />
module tf.graph.parser {
/**
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
index b0ee19a25e..fa0ee99d19 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
@@ -12,14 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="graph.ts" />
-/// <reference path="hierarchy.ts" />
-
/**
* Package for the Render Hierarchy for TensorFlow graph.
*/
-
module tf.graph.render {
export type Point = {x: number, y: number};
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts
index a50d31b5b9..b48d62c346 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts
@@ -12,13 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../graph.ts" />
-/// <reference path="../render.ts" />
-/// <reference path="scene.ts" />
-/// <reference path="edge.ts" />
-/// <reference path="contextmenu.ts" />
-
module tf.graph.scene.annotation {
/**
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
index d0f1e8fad6..2938aa3f1d 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
@@ -12,11 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../graph.ts" />
-/// <reference path="../render.ts" />
-/// <reference path="scene.ts" />
-
module tf.graph.scene.edge {
/** Delimiter between dimensions when showing sizes of tensors. */
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
index 72464c69c4..bd8917929f 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../common.ts" />
-
module tf.scene {
/** Show minimap when the viewpoint area is less than X% of the whole area. */
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
index cef46578b5..f2e73976ff 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
@@ -12,12 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../graph.ts" />
-/// <reference path="scene.ts" />
-/// <reference path="annotation.ts" />
-/// <reference path="contextmenu.ts" />
-
module tf.graph.scene.node {
/**
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
index 685ad646f7..b6eb3f7d81 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
@@ -12,12 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="../graph.ts" />
-/// <reference path="edge.ts" />
-/// <reference path="node.ts" />
-/// <reference path="../layout.ts" />
-
module tf.graph.scene {
/** Enums element class of objects in the scene */
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts
index 0423e1c863..93d1540939 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts
@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-/// <reference path="graph.ts" />
-/// <reference path="hierarchy.ts" />
-
module tf.graph.template {
/**
diff --git a/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html b/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html
index e97a1815c2..829769c3d0 100644
--- a/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html
+++ b/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html
@@ -1,40 +1,11 @@
<!DOCTYPE html>
<html>
- <head>
- <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script>
- <link rel="import" href="../tf-tensorboard.html">
+<head>
+ <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script>
+ <link rel="import" href="../tf-tensorboard-demo.html">
<link rel="stylesheet" type="text/css" href="../../../lib/css/global.css">
- <title>TensorBoard Demo</title>
- </head>
- <body>
- <base href="/">
- <dom-module id="x-demo">
- <template>
- <tf-tensorboard
- id="demo"
- router="[[demoRouter]]">
- </tf-tensorboard>
- </template>
- <script>
- var dataDir = "components/tf-tensorboard/demo/data/";
- var demoRouter = {
- runs: function() { return dataDir + "runs.json";},
- graph: function(run) {return dataDir + run + "-graph.pbtxt";},
- scalars: function(tag, run) {
- return dataDir + run.split("_")[0] + ".json";
- },
- };
- Polymer({
- is: "x-demo",
- properties: {
- demoRouter: {
- type: Object,
- value: demoRouter,
- },
- },
- });
- </script>
- </dom-module>
- <x-demo></x-demo>
- </body>
+</head>
+<body>
+ <tf-tensorboard-demo old-version="true" data-dir="data/"></tf-tensorboard-demo>
+</body>
</html>
diff --git a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html
index 8fe248aff0..abed65ef66 100644
--- a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html
+++ b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html
@@ -9,6 +9,7 @@ json data from a "dataDir" rather than connecting to a live backend.
<tf-tensorboard
id="tensorboard"
router="[[_demoRouter]]"
+ no-hash="[[noHash]]"
></tf-tensorboard>
<style>
:host {
@@ -23,15 +24,27 @@ json data from a "dataDir" rather than connecting to a live backend.
properties: {
_demoRouter: {
type: Object,
- computed: "_makeDemoRouter(dataDir)",
+ computed: "_makeDemoRouter(dataDir, oldVersion)",
},
dataDir: {
type: String,
value: "data",
},
+ // Whether to use the old version of the router, which can serve the
+ // demo/data folder that is checked into the repository.
+ oldVersion: {
+ type: Boolean,
+ value: false
+ },
+ // If true, tab switching in TensorBoard will not update the
+ // location hash. Hash updates interfere with selenium tests.
+ noHash: {
+ type: Boolean,
+ value: false
+ }
},
- _makeDemoRouter: function(dataDir) {
- return TF.Urls.demoRouter(dataDir);
+ _makeDemoRouter: function(dataDir, oldVersion) {
+ return TF.Urls.demoRouter(dataDir, oldVersion);
},
});
</script>
diff --git a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html
index bfcbb7ae5f..1c5ff47564 100644
--- a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html
+++ b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html
@@ -20,11 +20,11 @@ allows the user to toggle between various dashboards.
<paper-toolbar id="toolbar">
<div id="toolbar-content">
<div class="toolbar-title">TensorBoard</div>
- <paper-tabs selected="0" noink class="tabs" id="tabs">
- <paper-tab data-mode="events" on-click="changeMode">Events</paper-tab>
- <paper-tab data-mode="images" on-click="changeMode">Images</paper-tab>
- <paper-tab data-mode="graphs" on-click="changeMode">Graph</paper-tab>
- <paper-tab data-mode="histograms" on-click="changeMode">Histograms</paper-tab>
+ <paper-tabs selected="{{modeIndex}}" noink class="tabs" id="tabs">
+ <paper-tab data-mode="events">Events</paper-tab>
+ <paper-tab data-mode="images">Images</paper-tab>
+ <paper-tab data-mode="graphs">Graph</paper-tab>
+ <paper-tab data-mode="histograms">Histograms</paper-tab>
</paper-tabs>
</div>
</paper-toolbar>
@@ -111,14 +111,24 @@ allows the user to toggle between various dashboards.
type: Object,
value: TF.Urls.productionRouter(),
},
+ // Which tab is selected (events, graph, images etc).
mode: {
type: String,
- value: "events",
+ computed: '_getModeFromIndex(modeIndex)'
},
+ // If true, tab switching in TensorBoard will not update the
+ // location hash. Hash updates interfere with selenium tests.
+ noHash: {
+ type: Boolean,
+ value: false
+ }
},
- changeMode: function(ev) {
- var mode = ev.target.parentElement.getAttribute('data-mode');
- this._changeMode(mode, true);
+ _getModeFromIndex: function(modeIndex) {
+ var mode = this.tabs[modeIndex];
+ if (!this.noHash) {
+ window.location.hash = mode;
+ }
+ return mode;
},
eventDashboard: function(mode) {
return mode === "events";
@@ -132,36 +142,26 @@ allows the user to toggle between various dashboards.
histogramDashboard: function(mode) {
return mode === "histograms";
},
- loadPreviousMode: function() {
- this._changeMode(this._getMode(), false);
- },
ready: function() {
- this._changeMode(this._getMode(), true);
-
- var self = this;
- window.addEventListener('hashchange', function(){
- self.loadPreviousMode();
+ this.tabs = [].slice.call(this.querySelectorAll('paper-tab')).map(function(a) {
+ return a.dataset.mode;
});
+ this._getModeFromHash();
+ window.addEventListener('hashchange', function() {
+ this._getModeFromHash();
+ }.bind(this));
},
- _changeMode: function(mode, isNewState) {
- this.mode = mode;
-
- // Change the selected tab
- this.$.tabs.selected = this._tabs().indexOf(mode);
-
- if (isNewState){
- window.location.hash = mode;
- }
- },
- _getMode: function() {
+ _getModeFromHash: function() {
// Return the mode as it is stored in the hash.
- // If no mode can be found, default to the first tab.
- var hash = window.location.hash;
- return hash.length > 0 ? hash.slice(1, hash.length) : this._tabs()[0];
- },
- _tabs: function() {
- var elts = Array.prototype.slice.call(this.querySelectorAll('paper-tab'));
- return elts.map(function(elt){ return elt.getAttribute('data-mode')});
+ var tabName = window.location.hash.trim().slice(1);
+ var modeIndex = this.tabs.indexOf(tabName);
+ if (modeIndex == -1 && this.modeIndex == null) {
+ // Select the first tab as the default.
+ this.set('modeIndex', 0);
+ }
+ if (modeIndex != -1 && modeIndex != this.modeIndex) {
+ this.set('modeIndex', modeIndex);
+ }
},
});
</script>
diff --git a/tensorflow/tensorboard/components/tf-test/index.html b/tensorflow/tensorboard/components/tf-test/index.html
deleted file mode 100644
index d551750e3c..0000000000
--- a/tensorflow/tensorboard/components/tf-test/index.html
+++ /dev/null
@@ -1,16 +0,0 @@
-<!doctype html>
-<html>
-<head>
- <meta charset="utf-8">
- <script src="../web-component-tester/browser.js"></script>
-</head>
-<body>
-<script>
-// Run the tests for each main component in tensorboard.
-WCT.loadSuites([
- '../tf-graph-common/test/index.html',
- '../tf-graph-loader/test/index.html',
-]);
-</script>
-</body>
-</html>
diff --git a/tensorflow/tensorboard/gulpfile.js b/tensorflow/tensorboard/gulpfile.js
index dcc79f3008..6eeb24ddbe 100644
--- a/tensorflow/tensorboard/gulpfile.js
+++ b/tensorflow/tensorboard/gulpfile.js
@@ -98,8 +98,7 @@ gulp.task('compile.all', ['typings'], function() {
});
gulp.task('test', ['tslint-strict', 'compile.all'], function(done) {
- tester({suites: ['components/tf-test/'],
- plugins: {local: {}, sauce: false}}, function(error) {
+ tester({}, function(error) {
if (error) {
// Pretty error for gulp.
error = new Error(error.message || error);
diff --git a/tensorflow/tensorboard/lib/js/backend/test/index.html b/tensorflow/tensorboard/lib/js/backend/test/index.html
index 2305cf9426..7965ce6d0b 100644
--- a/tensorflow/tensorboard/lib/js/backend/test/index.html
+++ b/tensorflow/tensorboard/lib/js/backend/test/index.html
@@ -14,13 +14,9 @@ limitations under the License.
=============================================================================-->
<!doctype html>
<html>
-<!-- This test file has import paths that are suitable for gulp test and
- direct loading in the browser -->
<head>
<meta charset="utf-8">
- <script src="../../../../../components/webcomponentsjs/webcomponents-lite.min.js"></script>
- <script src="../../../../components/web-component-tester/browser.js"></script>
-
+ <script src="../../web-component-tester/browser.js"></script>
</head>
<body>
<script src="../../requestManager/requestManager.js"></script>
diff --git a/tensorflow/tensorboard/lib/js/nanite/test/index.html b/tensorflow/tensorboard/lib/js/nanite/test/index.html
index 0ac18a1bf2..2a886afe62 100644
--- a/tensorflow/tensorboard/lib/js/nanite/test/index.html
+++ b/tensorflow/tensorboard/lib/js/nanite/test/index.html
@@ -1,13 +1,10 @@
<!doctype html>
<html>
-<!-- This test file has import paths that are suitable for gulp test and
- direct loading in the browser -->
<head>
<meta charset="utf-8">
- <script src="../../../../../components/webcomponentsjs/webcomponents-lite.min.js"></script>
- <script src="../../../../components/web-component-tester/browser.js"></script>
- <link rel="import" href="../../../../../components/polymer/polymer.html">
-
+ <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script>
+ <script src="../../web-component-tester/browser.js"></script>
+ <link rel="import" href="../../polymer/polymer.html">
</head>
<body>
<script src="../nanite.js"></script>
diff --git a/tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts b/tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts
index ecc792944e..ba9dce0f57 100644
--- a/tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts
+++ b/tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts
@@ -14,11 +14,9 @@ limitations under the License.
==============================================================================*/
var assert = chai.assert;
declare function fixture(id: string): void;
-declare module HTMLImports {
- export function whenReady(f: Function): void;
-}
+
module TF.Nanite {
- HTMLImports.whenReady(function() {
+ window.HTMLImports.whenReady(function() {
Polymer({
is: "test-element",
properties: {
diff --git a/tensorflow/tensorboard/lib/js/node-radar/test/index.html b/tensorflow/tensorboard/lib/js/node-radar/test/index.html
index afb21ba15f..83c3018ed2 100644
--- a/tensorflow/tensorboard/lib/js/node-radar/test/index.html
+++ b/tensorflow/tensorboard/lib/js/node-radar/test/index.html
@@ -1,12 +1,8 @@
-
<!doctype html>
<html>
- <!-- This test file has import paths that are suitable for gulp test and
- direct loading in the browser -->
<head>
<meta charset="utf-8">
- <script src="../../../../../components/webcomponentsjs/webcomponents-lite.min.js"></script>
- <script src="../../../../components/web-component-tester/browser.js"></script>
+ <script src="../../web-component-tester/browser.js"></script>
</head>
<body>
<script src="../nodeRadar.js"></script>
diff --git a/tensorflow/tensorboard/lib/js/requestManager/test/index.html b/tensorflow/tensorboard/lib/js/requestManager/test/index.html
index b9712e8daf..53487f1f58 100644
--- a/tensorflow/tensorboard/lib/js/requestManager/test/index.html
+++ b/tensorflow/tensorboard/lib/js/requestManager/test/index.html
@@ -2,8 +2,7 @@
<html>
<head>
<meta charset="utf-8">
- <script src="../../../../../components/webcomponentsjs/webcomponents-lite.min.js"></script>
- <script src="../../../../components/web-component-tester/browser.js"></script>
+ <script src="../../web-component-tester/browser.js"></script>
</head>
<body>
<script src="../requestManager.js"></script>
diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json
index 1902bc4756..25bb35df67 100644
--- a/tensorflow/tensorboard/package.json
+++ b/tensorflow/tensorboard/package.json
@@ -25,7 +25,7 @@
"tslint": "^3.2.1",
"typescript": "1.8.0",
"vulcanize": "^1.14.0",
- "web-component-tester": "~3.4.2",
+ "web-component-tester": "4.2.2",
"gulp-header": "~1.7.1",
"gulp-rename": "~1.2.2",
"gulp-typings": "~1.1.0",
diff --git a/tensorflow/tensorboard/wct.conf.json b/tensorflow/tensorboard/wct.conf.json
new file mode 100644
index 0000000000..0a5c6c20b6
--- /dev/null
+++ b/tensorflow/tensorboard/wct.conf.json
@@ -0,0 +1,12 @@
+{
+ "suites": [
+ "components/tf-*/test",
+ "lib/js/*/test"
+ ],
+ "plugins": ["local"],
+ "webserver": {
+ "pathMappings": [
+ {"/components/<basename>/lib/js": "components"}
+ ]
+ }
+} \ No newline at end of file
diff --git a/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb b/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb
index 9bb889e41a..8f8bedbdfe 100644
--- a/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb
+++ b/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb
@@ -1273,7 +1273,7 @@
"batch_labels = train_labels[:BATCH_SIZE]\n",
"\n",
"# This dictionary maps the batch data (as a numpy array) to the\n",
- "# node in the graph is should be fed to.\n",
+ "# node in the graph it should be fed to.\n",
"feed_dict = {train_data_node: batch_data,\n",
" train_labels_node: batch_labels}\n",
"\n",
@@ -1680,7 +1680,7 @@
" batch_data = train_data[offset:(offset + BATCH_SIZE), :, :, :]\n",
" batch_labels = train_labels[offset:(offset + BATCH_SIZE)]\n",
" # This dictionary maps the batch data (as a numpy array) to the\n",
- " # node in the graph is should be fed to.\n",
+ " # node in the graph it should be fed to.\n",
" feed_dict = {train_data_node: batch_data,\n",
" train_labels_node: batch_labels}\n",
" # Run the graph and fetch some of the nodes.\n",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 1a8a026b7e..6debeabd97 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -10,8 +10,8 @@ def tf_workspace(path_prefix = ""):
native.new_http_archive(
name = "eigen_archive",
- url = "https://bitbucket.org/eigen/eigen/get/f1ce2528ee99.tar.gz",
- sha256 = "2c4ce322d13a613bbc53de3381760cf56f1c9b03c409233b764a6434ee1db909",
+ url = "https://bitbucket.org/eigen/eigen/get/88444e025a5c.tar.gz",
+ sha256 = "42e6f6de56b3ff010531a2bbf3e2db1db46be30d3965efb1eaa5634c5db013dd",
build_file = path_prefix + "eigen.BUILD",
)
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index af815350c8..95a503d611 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "eigen-eigen-f1ce2528ee99/Eigen/Cholesky"
+#include "eigen-eigen-88444e025a5c/Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index 1625edf8f7..b4a10f6ed1 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "eigen-eigen-f1ce2528ee99/Eigen/Core"
+#include "eigen-eigen-88444e025a5c/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index f5e92ae98a..56657aa837 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1 @@
-#include "eigen-eigen-f1ce2528ee99/Eigen/Eigenvalues"
+#include "eigen-eigen-88444e025a5c/Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index 77f592a412..3c491eeef9 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "eigen-eigen-f1ce2528ee99/Eigen/LU"
+#include "eigen-eigen-88444e025a5c/Eigen/LU"
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
index 2f1eeb9a6e..5a97880470 100644
--- a/third_party/eigen3/Eigen/QR
+++ b/third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "eigen-eigen-f1ce2528ee99/Eigen/QR"
+#include "eigen-eigen-88444e025a5c/Eigen/QR"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index b87d22f207..20150d0594 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1 +1 @@
-#include "eigen-eigen-f1ce2528ee99/unsupported/Eigen/CXX11/Tensor"
+#include "eigen-eigen-88444e025a5c/unsupported/Eigen/CXX11/Tensor"
diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template
index e90ec790fd..d2b1b0b25a 100644
--- a/tools/bazel.rc.template
+++ b/tools/bazel.rc.template
@@ -3,8 +3,6 @@ build:cuda --define=using_cuda=true
build --force_python=py$PYTHON_MAJOR_VERSION
build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
-build --define=use_fast_cpp_protos=true
-build --define=allow_oversize_protos=true
build --spawn_strategy=standalone
test --spawn_strategy=standalone