diff options
99 files changed, 2009 insertions, 1159 deletions
diff --git a/eigen.BUILD b/eigen.BUILD index 958772ee9d..806b6d36b9 100644 --- a/eigen.BUILD +++ b/eigen.BUILD @@ -1,6 +1,6 @@ package(default_visibility = ["//visibility:public"]) -archive_dir = "eigen-eigen-f1ce2528ee99" +archive_dir = "eigen-eigen-88444e025a5c" cc_library( name = "eigen", diff --git a/tensorflow/contrib/linear_optimizer/kernels/BUILD b/tensorflow/contrib/linear_optimizer/kernels/BUILD index 682fd6f822..2e56171211 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/BUILD +++ b/tensorflow/contrib/linear_optimizer/kernels/BUILD @@ -7,16 +7,27 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) cc_library( - name = "losses", + name = "loss_updaters", hdrs = [ "hinge-loss.h", "logistic-loss.h", + "loss.h", "squared-loss.h", ], deps = ["//tensorflow/core:lib"], ) -# TODO(katsiapis): Add tests for losses. +cc_test( + name = "loss_updaters_test", + srcs = ["loss_updaters_test.cc"], + deps = [ + ":loss_updaters", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) cc_library( name = "resources", @@ -44,7 +55,7 @@ cc_library( name = "sdca_ops", srcs = ["sdca_ops.cc"], deps = [ - ":losses", + ":loss_updaters", ":resources", "//third_party/eigen3", "//tensorflow/core:framework", diff --git a/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h b/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h index 877fceeeb4..3655fa707e 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h +++ b/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h @@ -19,8 +19,13 @@ limitations under the License. 
#include <algorithm> #include <cmath> +#include "tensorflow/contrib/linear_optimizer/kernels/loss.h" +#include "tensorflow/core/lib/core/errors.h" + namespace tensorflow { -struct hinge_loss { + +class HingeLossUpdater : public DualLossUpdater { + public: // Computes the updated dual variable (corresponding) to a single example. The // updated dual value maximizes the objective function of the dual // optimization problem associated with hinge loss (conditioned on keeping the @@ -30,13 +35,11 @@ struct hinge_loss { // and the particular form of conjugate function for hinge loss. // TODO(pmol): Write up a doc with concrete derivation and point to it from // here. - inline static double ComputeUpdatedDual(const double label, - const double example_weight, - const double current_dual, - const double wx, - const double weighted_example_norm, - const double primal_loss, - const double dual_loss) { + double ComputeUpdatedDual(const double label, const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm, + const double primal_loss, + const double dual_loss) const final { // Intutitvely there are 3 cases: // a. new optimal value of the dual variable falls withing the admissible // range [0, 1]. In this case we set new dual to this value. @@ -65,9 +68,8 @@ struct hinge_loss { // on its label. In particular: // \phi_y*(z) = y*z if y*z \in [-w, 0] and +infinity everywhere else where // y \in {-1,1}. The following method implements \phi_y*(-\alpha/w). - inline static double ComputeDualLoss(const double current_dual, - const double example_label, - const double example_weight) { + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { // For binary classification, there are 2 conjugate functions, one per // label value (-1 and 1). 
const double y_alpha = current_dual * example_label; // y \alpha @@ -80,13 +82,29 @@ struct hinge_loss { // Hinge loss for binary classification for a single example. Hinge loss // equals max(0, 1 - y * wx) (see https://en.wikipedia.org/wiki/Hinge_loss). // For weighted instances loss should be multiplied by the instance weight. - inline static double ComputePrimalLoss(const double wx, - const double example_label, - const double example_weight) { + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { const double y_wx = example_label * wx; return std::max(0.0, 1 - y_wx) * example_weight; } + + // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively + // as expected by hinge loss. + Status ConvertLabel(float* const example_label) const final { + if (*example_label == 0.0) { + *example_label = -1; + return Status::OK(); + } + if (*example_label == 1.0) { + return Status::OK(); + } + return errors::InvalidArgument( + "Only labels of 0.0 or 1.0 are supported right now. " + "Found example with label: ", + *example_label); + } }; + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_HINGE_LOSS_H_ diff --git a/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h b/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h index d75a707820..b18116be9d 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h +++ b/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h @@ -19,44 +19,21 @@ limitations under the License. #include <algorithm> #include <cmath> +#include "tensorflow/contrib/linear_optimizer/kernels/loss.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" namespace tensorflow { -struct logistic_loss { - // Partial derivative of the logistic loss w.r.t (1 + exp(-ywx)). 
- inline static double PartialDerivativeLogisticLoss(const double wx, - const double label) { - // To avoid overflow, we compute partial derivative of logistic loss as - // follows. - const double ywx = label * wx; - if (ywx > 0) { - const double exp_minus_ywx = exp(-ywx); - return exp_minus_ywx / (1 + exp_minus_ywx); - } - return 1 / (1 + exp(ywx)); - } - - // Smoothness constant for the logistic loss. - inline static double SmoothnessConstantLogisticLoss( - const double partial_derivative_loss, const double wx, - const double label) { - // Upper bound on the smoothness constant of log loss. This is 0.25 i.e. - // when log-odds is zero. - return (wx == 0) ? 0.25 - : (1 - 2 * partial_derivative_loss) / (2 * label * wx); - } +class LogisticLossUpdater : public DualLossUpdater { + public: // Use an approximate step that is guaranteed to decrease the dual loss. // Derivation of this is available in Page 14 Eq 16 of // http://arxiv.org/pdf/1211.2717v1.pdf - inline static double ComputeUpdatedDual(const double label, - const double example_weight, - const double current_dual, - const double wx, - const double weighted_example_norm, - const double primal_loss, - const double dual_loss) { + double ComputeUpdatedDual(const double label, const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm, + const double primal_loss, + const double dual_loss) const final { const double partial_derivative_loss = PartialDerivativeLogisticLoss(label, wx); // f(a) = sup (a*x - f(x)) then a = f'(x), where a is the aproximate dual. @@ -81,9 +58,8 @@ struct logistic_loss { // Dual of logisitic loss function. 
// https://en.wikipedia.org/wiki/Convex_conjugate - inline static double ComputeDualLoss(const double current_dual, - const double example_label, - const double example_weight) { + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { // Dual of the logistic loss function is // ay * log(ay) + (1-ay) * log (1-ay), where a is the dual variable. const double ay = current_dual * example_label; @@ -95,9 +71,8 @@ struct logistic_loss { // Logistic loss for binary classification. // https://en.wikipedia.org/wiki/Loss_functions_for_classification - inline static double ComputePrimalLoss(const double wx, - const double example_label, - const double example_weight) { + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { // Logistic loss: // log(1 + e^(-ywx)) // log(e^0 + e^(-ywx)) @@ -117,7 +92,7 @@ struct logistic_loss { // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively // as expected by logistic regression. - inline static Status ConvertLabel(float* const example_label) { + Status ConvertLabel(float* const example_label) const final { if (*example_label == 0.0) { *example_label = -1; return Status::OK(); @@ -130,6 +105,30 @@ struct logistic_loss { "Found example with label: ", *example_label); } + + private: + // Partial derivative of the logistic loss w.r.t (1 + exp(-ywx)). + static inline double PartialDerivativeLogisticLoss(const double wx, + const double label) { + // To avoid overflow, we compute partial derivative of logistic loss as + // follows. + const double ywx = label * wx; + if (ywx > 0) { + const double exp_minus_ywx = exp(-ywx); + return exp_minus_ywx / (1 + exp_minus_ywx); + } + return 1 / (1 + exp(ywx)); + } + + // Smoothness constant for the logistic loss. 
+ static inline double SmoothnessConstantLogisticLoss( + const double partial_derivative_loss, const double wx, + const double label) { + // Upper bound on the smoothness constant of log loss. This is 0.25 i.e. + // when log-odds is zero. + return (wx == 0) ? 0.25 + : (1 - 2 * partial_derivative_loss) / (2 * label * wx); + } }; } // namespace tensorflow diff --git a/tensorflow/contrib/linear_optimizer/kernels/loss.h b/tensorflow/contrib/linear_optimizer/kernels/loss.h new file mode 100644 index 0000000000..d827d6f764 --- /dev/null +++ b/tensorflow/contrib/linear_optimizer/kernels/loss.h @@ -0,0 +1,53 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOSS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOSS_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class DualLossUpdater { + public: + virtual ~DualLossUpdater() {} + + // Compute update dual (alpha), based on a single example. Various strategies + // can be employed here, like newton step and/or line search or approximate + // step that decreases the dual sub-optimality. 
+ virtual double ComputeUpdatedDual(const double label, + const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm, + const double primal_loss, + const double dual_loss) const = 0; + + // Compute dual loss based on the current dual (alpha), example label (y) + // and example weight (cost). + virtual double ComputeDualLoss(const double current_dual, + const double example_label, + const double example_weight) const = 0; + + // Compute the primal loss based on current estimate of log-odds(wx), + // example label (y) and example weight (cost). + virtual double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const = 0; + + // Converts binary example labels from 0.0 or 1.0 to appropriate range for + // each loss function. + virtual Status ConvertLabel(float* const example_label) const = 0; +}; + +} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOSS_H_ diff --git a/tensorflow/contrib/linear_optimizer/kernels/loss_updaters_test.cc b/tensorflow/contrib/linear_optimizer/kernels/loss_updaters_test.cc new file mode 100644 index 0000000000..7d9f05609b --- /dev/null +++ b/tensorflow/contrib/linear_optimizer/kernels/loss_updaters_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h" +#include "tensorflow/contrib/linear_optimizer/kernels/squared-loss.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(LogisticLoss, ComputePrimalLoss) { + LogisticLossUpdater loss_updater; + EXPECT_NEAR(0.693147, loss_updater.ComputePrimalLoss( + 0 /* wx */, 1 /* label */, 1 /* example weight */), + 1e-3); + EXPECT_NEAR(0.0, loss_updater.ComputePrimalLoss(70 /* wx */, 1 /* label */, + 1 /* example weight */), + 1e-3); + EXPECT_NEAR(0.0, loss_updater.ComputePrimalLoss(-70 /* wx */, -1 /* label */, + 1 /* example weight */), + 1e-3); +} + +TEST(LogisticLoss, ComputeDualLoss) { + LogisticLossUpdater loss_updater; + EXPECT_NEAR(0.0, + loss_updater.ComputeDualLoss(0 /* current dual */, 1 /* label */, + 1 /* example weight */), + 1e-3); + EXPECT_NEAR(0.0, + loss_updater.ComputeDualLoss(1 /* current dual */, 1 /* label */, + 1 /* example weight */), + 1e-3); + EXPECT_NEAR(-0.693147, loss_updater.ComputeDualLoss(0.5 /* current dual */, + 1 /* label */, + 1 /* example weight */), + 1e-3); +} + +// TODO(rohananil): Add tests for dual update. +// TODO(dbaylor): Add tests for squared loss. +// TODO(pmol): Add tests for hinge loss. 
+ +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources.cc b/tensorflow/contrib/linear_optimizer/kernels/resources.cc index 392ceac12a..d6266616f1 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/resources.cc +++ b/tensorflow/contrib/linear_optimizer/kernels/resources.cc @@ -44,6 +44,16 @@ DataByExample::Key DataByExample::MakeKey(const string& example_id) { Hash64(example_id.data(), example_id.size(), kSeed2) & 0xFFFFFFFF); } +DataByExample::Data DataByExample::Get(const Key& key) { + mutex_lock l(mu_); + return data_by_key_[key]; +} + +void DataByExample::Set(const Key& key, const Data& data) { + mutex_lock l(mu_); + data_by_key_[key] = data; +} + Status DataByExample::Visit( std::function<void(const Data& data)> visitor) const { struct State { @@ -71,8 +81,8 @@ Status DataByExample::Visit( // be successful if and only if the size of the backing store hasn't // changed (since the body of this while-loop is under lock). if (data_by_key_.size() != state.size) { - return errors::Aborted("The number of elements for ", solver_uuid_, - " has changed which nullifies a visit."); + return errors::Unavailable("The number of elements for ", solver_uuid_, + " has changed which nullifies a visit."); } for (size_t i = 0; i < kVisitChunkSize && state.num_visited < state.size; ++i, ++state.num_visited, ++state.it) { diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources.h b/tensorflow/contrib/linear_optimizer/kernels/resources.h index cb0ea8433e..4578e3442f 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/resources.h +++ b/tensorflow/contrib/linear_optimizer/kernels/resources.h @@ -47,8 +47,10 @@ class DataByExample : public ResourceBase { static Key MakeKey(const string& example_id); struct Data { - // TODO(rohananil): Add extra data needed for duality gap computation here. 
float dual = 0; + float primal_loss = 0; + float dual_loss = 0; + float example_weight = 0; // Comparison operators for ease of testing. bool operator==(const Data& other) const { return dual == other.dual; } @@ -58,22 +60,17 @@ class DataByExample : public ResourceBase { // Accessor and mutator for the entry at Key. Accessor creates an entry with // default value (default constructed object) if the key is not present and // returns it. - inline Data Get(const Key& key) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - return data_by_key_[key]; - } - inline void Set(const Key& key, const Data& data) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - data_by_key_[key] = data; - } + Data Get(const Key& key) LOCKS_EXCLUDED(mu_); + void Set(const Key& key, const Data& data) LOCKS_EXCLUDED(mu_); // Visits all elements in this resource. The view of each element (Data) is // atomic, but the entirety of the visit is not (ie the visitor might see // different versions of the Data across elements). // - // Returns OK on success or ABORTED if the number of elements in this + // Returns OK on success or UNAVAILABLE if the number of elements in this // container has changed since the beginning of the visit (in which case the - // visit cannot be completed and is aborted early). + // visit cannot be completed and is aborted early, and computation can be + // restarted). Status Visit(std::function<void(const Data& data)> visitor) const LOCKS_EXCLUDED(mu_); @@ -86,8 +83,11 @@ class DataByExample : public ResourceBase { // Backing container. // - // sizeof(EntryPayload) = sizeof(Key) + sizeof(Data) = 16. - // So on average we use ~35 bytes per entry in this table. + // sizeof(EntryPayload) = + // sizeof(Key) + sizeof(Data) = + // 12 + 16 = 28. + // + // So on average we use ~47.5 (28 + 19.5) bytes per entry in this table. using DataByKey = std::unordered_map<Key, Data, KeyHash>; // TODO(katsiapis): Benchmark and/or optimize this. 
diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc index 9a94c54bc5..1981db3160 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc +++ b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc @@ -103,7 +103,7 @@ TEST_F(DataByExampleTest, VisitMany) { (kNumElements - 1) * kNumElements / 2.0, total_dual); } -TEST_F(DataByExampleTest, VisitAborted) { +TEST_F(DataByExampleTest, VisitUnavailable) { // Populate enough entries so that Visiting will be chunked. for (size_t i = 0; i < 2 * VisitChunkSize(); ++i) { data_by_example_->Get(DataByExample::MakeKey(strings::StrCat(i))); @@ -151,7 +151,7 @@ TEST_F(DataByExampleTest, VisitAborted) { }); wait(&completed_visit); EXPECT_FALSE(thread_pool.HasPendingClosures()); - EXPECT_TRUE(errors::IsAborted(status)); + EXPECT_TRUE(errors::IsUnavailable(status)); } } // namespace tensorflow diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc index 209c9a85e8..76671a47cd 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc +++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc @@ -78,7 +78,6 @@ struct RegularizationLoss { }; struct PerExampleData { - double old_wx = 0; double wx = 0; double norm = 0; }; @@ -256,7 +255,6 @@ inline PerExampleData ComputeWxAndWeightedExampleNorm( const int64 index = indices(dim); const double weight = weights(index); const double value = values(dim); - result.old_wx += Shrink(weight, shrink_by) * value; result.wx += Shrink(weight + delta_weights(index), shrink_by) * value; } result.norm += sparse_indices_values[example_id]->norm; @@ -265,7 +263,6 @@ inline PerExampleData ComputeWxAndWeightedExampleNorm( for (size_t i = 0; i < dense_features_by_group.size(); ++i) { const double weight = dense_weights_by_group[i](0); const double value = dense_features_by_group[i](example_id); - result.old_wx += 
Shrink(weight, shrink_by) * value; result.wx += Shrink(weight + dense_delta_weights_by_group[i](0), shrink_by) * value; result.norm += value * value; @@ -300,6 +297,7 @@ void AddDeltaWeights(const WeightsByGroup& src, WeightsByGroup* const dst) { void ShrinkWeights(const Regularizations& regularizations, WeightsByGroup* const sparse_weights_by_group, WeightsByGroup* const dense_weights_by_group) { + // TODO(rohananil): Parallelize shrinking. const double shrink_by = ShrinkageFactor(regularizations); for (TTypes<float>::Vec weights : *sparse_weights_by_group) { for (int64 i = 0; i < weights.size(); ++i) { @@ -380,6 +378,88 @@ WeightsByGroup MakeDeltaWeightsFrom(std::vector<Tensor>* const tensors) { return result; } +Status RunTrainStepsForMiniBatch( + const std::vector<int64>& example_indices, + const TTypes<const string>::Vec example_ids, + const TTypes<const float>::Vec example_labels, + const TTypes<const float>::Vec example_weights, + const DeviceBase::CpuWorkerThreads& worker_threads, + const Regularizations& regularizations, + const WeightsByGroup& sparse_weights_by_group, + const SparseExamplesByGroup& sparse_examples_by_group, + const WeightsByGroup& dense_weights_by_group, + const DenseFeaturesByGroup& dense_features_by_group, + const DualLossUpdater& loss_updater, + WeightsByGroup* const sparse_delta_weights_by_group, + WeightsByGroup* const dense_delta_weights_by_group, + DataByExample* const data_by_example) { + // Process examples in parallel, in a partitioned fashion. + mutex mu; + Status train_step_status GUARDED_BY(mu); + auto train_step = [&](const int64 begin, const int64 end) { + for (int64 offset = begin; offset < end; ++offset) { + // Get example id, label, and weight. 
+ const int64 example_index = example_indices[offset]; + const DataByExample::Key example_key = + DataByExample::MakeKey(example_ids(example_index)); + DataByExample::Data data = data_by_example->Get(example_key); + const double example_weight = example_weights(example_index); + float example_label = example_labels(example_index); + const Status conversion_status = + loss_updater.ConvertLabel(&example_label); + if (!conversion_status.ok()) { + mutex_lock l(mu); + train_step_status = conversion_status; + // Return from this worker thread - the calling thread is + // responsible for checking context status and returning on error. + return; + } + + // Compute wx, example norm weighted by regularization, dual loss, + // primal loss. + const PerExampleData per_example_data = ComputeWxAndWeightedExampleNorm( + example_index, sparse_weights_by_group, + *sparse_delta_weights_by_group, sparse_examples_by_group, + dense_weights_by_group, *dense_delta_weights_by_group, + dense_features_by_group, regularizations); + + const double primal_loss = loss_updater.ComputePrimalLoss( + per_example_data.wx, example_label, example_weight); + + const double dual_loss = loss_updater.ComputeDualLoss( + data.dual, example_label, example_weight); + + const double new_dual = loss_updater.ComputeUpdatedDual( + example_label, example_weight, data.dual, per_example_data.wx, + per_example_data.norm, primal_loss, dual_loss); + + // Compute new weights. + const double bounded_dual_delta = (new_dual - data.dual) * example_weight; + UpdateDeltaWeights( + example_index, sparse_examples_by_group, dense_features_by_group, + bounded_dual_delta, regularizations.symmetric_l2, + sparse_delta_weights_by_group, dense_delta_weights_by_group); + + // Update example data. 
+ data.dual = new_dual; + data.primal_loss = primal_loss; + data.dual_loss = dual_loss; + data.example_weight = example_weight; + data_by_example->Set(example_key, data); + } + // TODO(rohananil): We may in the future want to make the primal-dual + // relationship consistent as our current updates are not + // transactional. + }; + // TODO(rohananil): Current multiplier 100000 works well empirically + // but perhaps we can tune it better. + const int64 kCostPerUnit = 100000 * (sparse_examples_by_group.size() + + dense_features_by_group.size()); + Shard(worker_threads.num_threads, worker_threads.workers, + example_indices.size(), kCostPerUnit, train_step); + return train_step_status; +} + } // namespace class SdcaSolver : public OpKernel { @@ -388,21 +468,11 @@ class SdcaSolver : public OpKernel { string loss_type; OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type)); if (loss_type == "logistic_loss") { - compute_dual_loss_ = logistic_loss::ComputeDualLoss; - compute_primal_loss_ = logistic_loss::ComputePrimalLoss; - compute_dual_update_ = logistic_loss::ComputeUpdatedDual; - convert_label_ = logistic_loss::ConvertLabel; + loss_updater_.reset(new LogisticLossUpdater); } else if (loss_type == "squared_loss") { - compute_dual_loss_ = squared_loss::ComputeDualLoss; - compute_primal_loss_ = squared_loss::ComputePrimalLoss; - compute_dual_update_ = squared_loss::ComputeUpdatedDual; - convert_label_ = squared_loss::ConvertLabel; + loss_updater_.reset(new SquaredLossUpdater); } else if (loss_type == "hinge_loss") { - compute_dual_loss_ = hinge_loss::ComputeDualLoss; - compute_primal_loss_ = hinge_loss::ComputePrimalLoss; - compute_dual_update_ = hinge_loss::ComputeUpdatedDual; - // Label conversion is identical for hinge and logistic loss. 
- convert_label_ = logistic_loss::ConvertLabel; + loss_updater_.reset(new HingeLossUpdater); } else { OP_REQUIRES(context, false, errors::InvalidArgument( "Unsupported loss type: ", loss_type)); @@ -424,8 +494,16 @@ class SdcaSolver : public OpKernel { regularizations_.symmetric_l2 = std::max(regularizations_.symmetric_l2, 1.0f); - OP_REQUIRES_OK(context, context->GetAttr("duality_gap_threshold", - &duality_gap_threshold_)); + OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations", + &num_inner_iterations_)); + + // TODO(rohananil): Provide emperical evidence for this. It is better to run + // more than one iteration on single mini-batch as we want to spend more + // time in compute. SDCA works better with larger mini batches and there + // is also recent work that shows its better to reuse old samples than train + // on new samples. See: http://arxiv.org/abs/1602.02136. + num_inner_iterations_ = + std::max(num_inner_iterations_, static_cast<int64>(2)); OP_REQUIRES_OK(context, context->GetAttr("container", &container_)); OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_)); } @@ -448,7 +526,7 @@ class SdcaSolver : public OpKernel { })); OP_REQUIRES( context, !data_by_example->RefCountIsOne(), - errors::Internal("Expected shared-ownership of duals_by_example.")); + errors::Internal("Expected shared-ownership of data_by_example.")); const Tensor* example_weights_t; OP_REQUIRES_OK(context, @@ -467,14 +545,6 @@ class SdcaSolver : public OpKernel { errors::InvalidArgument("No weighted examples in ", num_examples, " training examples")); - Tensor primal_loss_t; - OP_REQUIRES_OK(context, - context->mutable_input("primal_loss", &primal_loss_t, - /*lock_held=*/true)); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(primal_loss_t.shape()), - errors::InvalidArgument("primal_loss should be a scalar.")); - auto primal_loss = primal_loss_t.scalar<double>(); - OpInputList dense_features_inputs; OP_REQUIRES_OK( context, 
context->input_list("dense_features", &dense_features_inputs)); @@ -527,6 +597,7 @@ class SdcaSolver : public OpKernel { OpMutableInputList sparse_weights_inputs; OP_REQUIRES_OK(context, context->mutable_input_list( "sparse_weights", &sparse_weights_inputs)); + WeightsByGroup sparse_weights_by_group = MakeWeightsFrom(&sparse_weights_inputs); std::vector<Tensor> sparse_delta_weights_by_group_backing_store = @@ -534,11 +605,10 @@ class SdcaSolver : public OpKernel { WeightsByGroup sparse_delta_weights_by_group = MakeDeltaWeightsFrom(&sparse_delta_weights_by_group_backing_store); - // TODO(rohananil): Remove the code duplication between sparse and - // dense weights. OpMutableInputList dense_weights_inputs; OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights", &dense_weights_inputs)); + WeightsByGroup dense_weights_by_group = MakeWeightsFrom(&dense_weights_inputs); std::vector<Tensor> dense_delta_weights_by_group_backing_store = @@ -563,140 +633,165 @@ class SdcaSolver : public OpKernel { *context->device()->tensorflow_cpu_worker_threads(), &sparse_examples_by_group)); - // Those will be shuffled below at each iteration and processed in a - // partitioned fashion across multiple threads. 
- std::vector<int64> example_indices(num_examples); - std::iota(example_indices.begin(), example_indices.end(), 0); - - std::random_device random_device; - std::mt19937 random_generator(random_device()); - - // Break when duality gap |P(w) - D(alpha)| is less than - // duality_gap_threshold_ - double total_duality_gap = std::numeric_limits<double>::max(); - while ((total_duality_gap / weighted_examples) > duality_gap_threshold_) { - std::atomic<double> total_primal_loss(0); - std::atomic<double> total_dual_loss(0); - SetZeroDeltaWeights(&sparse_delta_weights_by_group, - &dense_delta_weights_by_group); - - // Compute regularization loss at the start of the iteration so that - // we can compute an exact value of duality gap (for the weights from - // the previous iteration). - const RegularizationLoss regularization_loss = ComputeRegularizationLoss( - sparse_weights_by_group, dense_weights_by_group, regularizations_); + SetZeroDeltaWeights(&sparse_delta_weights_by_group, + &dense_delta_weights_by_group); + // Examples are shuffled below at each iteration and processed in a + // partitioned fashion across multiple threads. + // TODO(rohananil): We may want to avoid shuffling inside the op + // but instead shuffle outside as part of the input reader. Re-evaluate + // once we have more data on how this op is used. + const std::vector<int64> example_indices = [num_examples]() { + std::vector<int64> result(num_examples); + std::iota(result.begin(), result.end(), 0); + std::random_device random_device; + std::mt19937 random_generator(random_device()); // Randomize the examples across iterations for faster convergence. - std::shuffle(example_indices.begin(), example_indices.end(), - random_generator); - - { - // Process examples in parallel, in a partitioned fashion. 
- mutex mu; - Status update_status GUARDED_BY(mu); - auto update_partition = [&](const int64 begin, const int64 end) { - double dual_loss_on_example_subset = 0; - double primal_loss_on_example_subset = 0; - for (int64 offset = begin; offset < end; ++offset) { - // Get example id, label, and weight. - const int64 example_index = example_indices[offset]; - const DataByExample::Key example_key = - DataByExample::MakeKey(example_ids(example_index)); - DataByExample::Data data = data_by_example->Get(example_key); - const double example_weight = example_weights(example_index); - float example_label = example_labels(example_index); - const Status conversion_status = convert_label_(&example_label); - if (!conversion_status.ok()) { - mutex_lock l(mu); - update_status = conversion_status; - // Return from this worker thread - the calling thread is - // responsible for checking context status and returning on error. - return; - } - - // Compute wx, example norm weighted by regularization, dual loss, - // primal loss. - const PerExampleData per_example_data = - ComputeWxAndWeightedExampleNorm( - example_index, sparse_weights_by_group, - sparse_delta_weights_by_group, sparse_examples_by_group, - dense_weights_by_group, dense_delta_weights_by_group, - dense_features_by_group, regularizations_); - // Compute primal based on the previous iteration. - primal_loss_on_example_subset += compute_primal_loss_( - per_example_data.old_wx, example_label, example_weight); - - const double primal_loss = compute_primal_loss_( - per_example_data.wx, example_label, example_weight); - - const double dual_loss = - compute_dual_loss_(data.dual, example_label, example_weight); - dual_loss_on_example_subset += dual_loss; - - const double new_dual = compute_dual_update_( - example_label, example_weight, data.dual, per_example_data.wx, - per_example_data.norm, primal_loss, dual_loss); - - // Compute new weights. 
- const double bounded_dual_delta = - (new_dual - data.dual) * example_weight; - UpdateDeltaWeights(example_index, sparse_examples_by_group, - dense_features_by_group, bounded_dual_delta, - regularizations_.symmetric_l2, - &sparse_delta_weights_by_group, - &dense_delta_weights_by_group); - - // Update dual variable. - data.dual = new_dual; - data_by_example->Set(example_key, data); - } - AtomicAdd(primal_loss_on_example_subset, &total_primal_loss); - AtomicAdd(dual_loss_on_example_subset, &total_dual_loss); - // TODO(rohananil): We may in the future want to make the primal-dual - // relationship consistent as our current updates are not - // transactional. - }; - const DeviceBase::CpuWorkerThreads& worker_threads = - *context->device()->tensorflow_cpu_worker_threads(); - // TODO(katsiapis): Current multiplier (100,000) works well empirically - // but perhaps we can tune it better. - const int64 kCostPerUnit = - 100000 * (num_sparse_features_ + num_dense_features_); - Shard(worker_threads.num_threads, worker_threads.workers, num_examples, - kCostPerUnit, update_partition); - OP_REQUIRES_OK(context, update_status); - } + std::shuffle(result.begin(), result.end(), random_generator); + return result; + }(); - total_duality_gap = total_primal_loss.load() + total_dual_loss.load() + - regularization_loss.l1_loss + - regularization_loss.l2_loss; - primal_loss() = (total_primal_loss.load() + regularization_loss.l1_loss + - regularization_loss.l2_loss) / - weighted_examples; - AddDeltaWeights(sparse_delta_weights_by_group, &sparse_weights_by_group); - AddDeltaWeights(dense_delta_weights_by_group, &dense_weights_by_group); + // TODO(rohananil): Provide emperical evidence for this. It is better to run + // more than one iteration on single mini-batch as we want to spend more + // time in compute. SDCA works better with larger mini batches and there + // is also recent work that shows its better to reuse old samples than train + // on new samples. 
See: http://arxiv.org/abs/1602.02136. + for (int64 i = 0; i < num_inner_iterations_; ++i) { + OP_REQUIRES_OK( + context, + RunTrainStepsForMiniBatch( + example_indices, example_ids, example_labels, example_weights, + *context->device()->tensorflow_cpu_worker_threads(), + regularizations_, sparse_weights_by_group, + sparse_examples_by_group, dense_weights_by_group, + dense_features_by_group, *loss_updater_, + &sparse_delta_weights_by_group, &dense_delta_weights_by_group, + data_by_example)); } - ShrinkWeights(regularizations_, &sparse_weights_by_group, - &dense_weights_by_group); + + // TODO(rohananil): Change to atomic<float> as we are not exposing delta + // weights to users. This will allow us to simplify the code that currently + // keeps a backing store for the tensors. This also avoids losing updates + // when done in a lockless way. + AddDeltaWeights(sparse_delta_weights_by_group, &sparse_weights_by_group); + AddDeltaWeights(dense_delta_weights_by_group, &dense_weights_by_group); // TODO(katsiapis): Use core::ScopedUnref once it's moved out of internal. data_by_example->Unref(); } private: - std::function<decltype(logistic_loss::ComputeDualLoss)> compute_dual_loss_; - std::function<decltype(logistic_loss::ComputePrimalLoss)> - compute_primal_loss_; - std::function<decltype(logistic_loss::ComputeUpdatedDual)> - compute_dual_update_; - std::function<decltype(logistic_loss::ConvertLabel)> convert_label_; + // TODO(rohananil): We could use the type-constraint on loss_type, and + // template the entire class to avoid the virtual table lookup penalty in + // the inner loop.
+ std::unique_ptr<DualLossUpdater> loss_updater_; int64 num_sparse_features_; int64 num_dense_features_; Regularizations regularizations_; - float duality_gap_threshold_; + int64 num_inner_iterations_; string container_; string solver_uuid_; }; REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver); +class SdcaShrinkL1 : public OpKernel { + public: + explicit SdcaShrinkL1(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("l1", &regularizations_.symmetric_l1)); + OP_REQUIRES_OK(context, + context->GetAttr("l2", &regularizations_.symmetric_l2)); + // We enforce a minimal l2, required by the algorithm. + regularizations_.symmetric_l2 = + std::max(regularizations_.symmetric_l2, 1.0f); + } + + void Compute(OpKernelContext* context) override { + OpMutableInputList sparse_weights_inputs; + OP_REQUIRES_OK(context, context->mutable_input_list( + "sparse_weights", &sparse_weights_inputs)); + WeightsByGroup sparse_weights_by_group = + MakeWeightsFrom(&sparse_weights_inputs); + + OpMutableInputList dense_weights_inputs; + OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights", + &dense_weights_inputs)); + WeightsByGroup dense_weights_by_group = + MakeWeightsFrom(&dense_weights_inputs); + + ShrinkWeights(regularizations_, &sparse_weights_by_group, + &dense_weights_by_group); + } + + private: + Regularizations regularizations_; +}; +REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1); + +class ComputeDualityGap : public OpKernel { + public: + explicit ComputeDualityGap(OpKernelConstruction* context) + : OpKernel(context) { + // TODO(rohananil): Refactor grabbing common attributes across ops related + // to sdca. + OP_REQUIRES_OK(context, + context->GetAttr("l1", &regularizations_.symmetric_l1)); + OP_REQUIRES_OK(context, + context->GetAttr("l2", &regularizations_.symmetric_l2)); + // We enforce a minimal l2, required by the algorithm.
+ regularizations_.symmetric_l2 = + std::max(regularizations_.symmetric_l2, 1.0f); + OP_REQUIRES_OK(context, context->GetAttr("container", &container_)); + OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_)); + } + + void Compute(OpKernelContext* context) override { + DataByExample* data_by_example = nullptr; + OP_REQUIRES_OK(context, context->resource_manager()->Lookup<DataByExample>( + container_, solver_uuid_, &data_by_example)); + OP_REQUIRES( + context, !data_by_example->RefCountIsOne(), + errors::Internal("Expected shared-ownership of data_by_example.")); + + OpMutableInputList sparse_weights_inputs; + OP_REQUIRES_OK(context, context->mutable_input_list( + "sparse_weights", &sparse_weights_inputs)); + WeightsByGroup sparse_weights_by_group = + MakeWeightsFrom(&sparse_weights_inputs); + + OpMutableInputList dense_weights_inputs; + OP_REQUIRES_OK(context, context->mutable_input_list("dense_weights", + &dense_weights_inputs)); + WeightsByGroup dense_weights_by_group = + MakeWeightsFrom(&dense_weights_inputs); + + double example_weight_sum = 0; + double total_duality_gap = 0; + OP_REQUIRES_OK(context, + data_by_example->Visit([&](const DataByExample::Data& data) { + example_weight_sum += data.example_weight; + total_duality_gap += data.primal_loss + data.dual_loss; + })); + + const RegularizationLoss regularization_loss = ComputeRegularizationLoss( + sparse_weights_by_group, dense_weights_by_group, regularizations_); + total_duality_gap += + regularization_loss.l2_loss + regularization_loss.l1_loss; + + Tensor* duality_gap_t = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("duality_gap", {}, &duality_gap_t)); + duality_gap_t->scalar<float>()() = total_duality_gap / example_weight_sum; + + // TODO(katsiapis): Use core::ScopedUnref once it's moved out of internal. 
+ data_by_example->Unref(); + } + + private: + Regularizations regularizations_; + string container_; + string solver_uuid_; +}; +REGISTER_KERNEL_BUILDER(Name("ComputeDualityGap").Device(DEVICE_CPU), + ComputeDualityGap); } // namespace tensorflow diff --git a/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h b/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h index 94d3c6f8c4..fc37a98a4f 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h +++ b/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h @@ -19,19 +19,19 @@ limitations under the License. #include <algorithm> #include <cmath> -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/contrib/linear_optimizer/kernels/loss.h" namespace tensorflow { -struct squared_loss { + +class SquaredLossUpdater : public DualLossUpdater { + public: // Closed form solution that decreases the dual squared loss. // See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf - inline static double ComputeUpdatedDual(const double label, - const double example_weight, - const double current_dual, - const double wx, - const double weighted_example_norm, - const double primal_loss_unused, - const double dual_loss_unused) { + double ComputeUpdatedDual(const double label, const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm, + const double primal_loss_unused, + const double dual_loss_unused) const final { const double delta_numerator = (label - current_dual - wx) * example_weight; const double delta_denominator = 1 + weighted_example_norm * example_weight * example_weight * 0.5; @@ -40,27 +40,26 @@ struct squared_loss { // Dual of squared loss function. 
// https://en.wikipedia.org/wiki/Convex_conjugate - inline static double ComputeDualLoss(const double current_dual, - const double example_label, - const double example_weight) { + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { // Dual of the squared loss function = b * (y + b/2), where b is the // dual variable and y is the label. This is Dual(-b). return current_dual * (0.5 * current_dual - example_label) * example_weight; } // Squared loss for linear regression. - inline static double ComputePrimalLoss(const double wx, - const double example_label, - const double example_weight) { + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { const double error = wx - example_label; return error * error * example_weight * 0.5; } // Labels don't require conversion for linear regression. - inline static Status ConvertLabel(float* const example_label) { + Status ConvertLabel(float* const example_label) const final { return Status::OK(); } }; + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_SQUARED_LOSS_H_ diff --git a/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc index fb3c8154dd..ff2bae8fea 100644 --- a/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc +++ b/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc @@ -23,8 +23,8 @@ REGISTER_OP("SdcaSolver") .Attr("num_sparse_features: int >= 0") .Attr("num_dense_features: int >= 0") .Attr("l1: float >= 0") - .Attr("l2: float >= 0") - .Attr("duality_gap_threshold: float = 0.01") + .Attr("l2: float >= 1") + .Attr("num_inner_iterations: int >= 2") .Attr("container: string") .Attr("solver_uuid: string") .Input("sparse_features_indices: num_sparse_features * int64") @@ -35,7 +35,6 @@ REGISTER_OP("SdcaSolver") .Input("example_ids: string") .Input("sparse_weights: Ref(num_sparse_features * 
float)") .Input("dense_weights: Ref(num_dense_features * float)") - .Input("primal_loss: Ref(double)") .Doc(R"doc( Stochastic Dual Coordinate Ascent (SDCA) optimizer for linear models with L1 + L2 regularization. As global optimization objective is strongly-convex, the @@ -54,7 +53,7 @@ num_sparse_features: Number of sparse feature groups to train on. num_dense_features: Number of dense feature groups to train on. l1: Symmetric l1 regularization strength. l2: Symmetric l2 regularization strength. -duality_gap_threshold: Gap threshold at which we should stop training. +num_inner_iterations: Number of iterations per mini-batch. container: Name of the Container that stores data across invocations of this Kernel. Together with SolverUUID form an isolation unit for this solver. solver_uuid: Universally Unique Identifier for this solver. @@ -75,4 +74,53 @@ dense_weights: a list of vectors where the value is the weight associated with a dense feature group. )doc"); +REGISTER_OP("SdcaShrinkL1") + .Attr("num_sparse_features: int >= 0") + .Attr("num_dense_features: int >= 0") + .Attr("l1: float >= 0") + .Attr("l2: float >= 1") + .Input("sparse_weights: Ref(num_sparse_features * float)") + .Input("dense_weights: Ref(num_dense_features * float)") + .Doc(R"doc( +Applies L1 regularization shrink step on the parameters. + +num_sparse_features: Number of sparse feature groups to train on. +num_dense_features: Number of dense feature groups to train on. +l1: Symmetric l1 regularization strength. +l2: Symmetric l2 regularization strength. +sparse_weights: a list of vectors where each value is the weight associated with + a feature index. +dense_weights: a list of vectors where the value is the weight associated with + a dense feature group. +)doc"); + +// TODO(katsiapis): We should expand the scope of this op to compute other +// statistics about the data.
+REGISTER_OP("ComputeDualityGap") + .Attr("num_sparse_features: int >= 0") + .Attr("num_dense_features: int >= 0") + .Attr("l1: float >= 0") + .Attr("l2: float >= 1") + .Attr("container: string") + .Attr("solver_uuid: string") + .Input("sparse_weights: Ref(num_sparse_features * float)") + .Input("dense_weights: Ref(num_dense_features * float)") + .Output("duality_gap: float") + .Doc(R"doc( +Computes duality gap over all examples seen by the optimizer. + +num_sparse_features: Number of sparse feature groups to train on. +num_dense_features: Number of dense feature groups to train on. +l1: Symmetric l1 regularization strength. +l2: Symmetric l2 regularization strength. +container: Name of the Container that stores data across invocations of this + Kernel. Together with SolverUUID form an isolation unit for this solver. +solver_uuid: Universally Unique Identifier for this solver. +sparse_weights: a list of vectors where each value is the weight associated with + a feature index. +dense_weights: a list of vectors where the value is the weight associated with + a dense feature group. +duality_gap: duality gap over all examples seen by the optimizer. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 71994fd51e..13968457f7 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -89,11 +89,8 @@ def make_variable_dict(max_age, max_gender): # examples_dict. 
age_weights = tf.Variable(tf.zeros([max_age + 1], dtype=tf.float32)) gender_weights = tf.Variable(tf.zeros([max_gender + 1], dtype=tf.float32)) - primal_loss = tf.Variable(tf.zeros([], dtype=tf.float64)) return dict(sparse_features_weights=[age_weights, gender_weights], - dense_features_weights=[], - primal_loss=primal_loss) - + dense_features_weights=[]) def make_dense_variable_dict(num_dense_features, num_examples): feature_weights = ([ @@ -121,6 +118,7 @@ def get_binary_predictions_for_hinge(predictions): all_ones = tf.ones_like(predictions) return tf.add(tf.sign(predictions), all_ones) / 2 + # Setup the single container shared across all tests. This is testing proper # isolation across optimizers instantiated in each of the tests below. CONTAINER = uuid.uuid4().hex @@ -155,22 +153,32 @@ class SdcaOptimizerTest(TensorFlowTestCase): variables = make_variable_dict(1, 1) options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, - loss_type='logistic_loss', - prior=0.0) - tf.initialize_all_variables().run() + loss_type='logistic_loss') + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() unregularized_loss = lr.unregularized_loss(examples) loss = lr.regularized_loss(examples) predictions = lr.predictions(examples) self.assertAllClose(0.693147, unregularized_loss.eval()) self.assertAllClose(0.693147, loss.eval()) - lr.minimize().run() - self.assertAllClose(0.395226, unregularized_loss.eval(), - rtol=3e-2, atol=3e-2) - self.assertAllClose(0.657446, loss.eval(), - rtol=3e-2, atol=3e-2) + for _ in xrange(5): + lr.minimize().run() + # The high tolerance in unregularized_loss comparisons is due to the + # fact that it's possible to trade off unregularized_loss vs. + # regularization and still have a sum that is quite close to the + # optimal regularized_loss value. SDCA's duality gap only ensures that + # the regularized_loss is within 0.01 of optimal. + # 0.525457 is the optimal regularized_loss. 
+ # 0.411608 is the unregularized_loss at that optimum. + self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.11) + self.assertAllClose(0.525457, loss.eval(), atol=0.01) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 1], predicted_labels.eval()) + self.assertAllClose(0.01, + lr.approximate_duality_gap().eval(), + rtol=1e-2, + atol=1e-2) def testSomeUnweightedExamples(self): # Setup test data with 4 examples, but should produce the same @@ -201,18 +209,22 @@ class SdcaOptimizerTest(TensorFlowTestCase): options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, loss_type='logistic_loss') - tf.initialize_all_variables().run() + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() unregularized_loss = lr.unregularized_loss(examples) loss = lr.regularized_loss(examples) predictions = lr.predictions(examples) - lr.minimize().run() - self.assertAllClose(0.395226, unregularized_loss.eval(), - rtol=3e-2, atol=3e-2) - self.assertAllClose(0.657446, loss.eval(), - rtol=3e-2, atol=3e-2) + for _ in xrange(5): + lr.minimize().run() + self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.12) + self.assertAllClose(0.525457, loss.eval(), atol=0.01) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllClose([0, 1, 1, 1], predicted_labels.eval()) + self.assertAllClose(0.01, + lr.approximate_duality_gap().eval(), + rtol=1e-2, + atol=1e-2) def testFractionalLogisticExample(self): # Setup test data with 1 positive, and 1 mostly-negative example. 
@@ -231,10 +243,12 @@ class SdcaOptimizerTest(TensorFlowTestCase): options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, loss_type='logistic_loss') + + lr = SdcaModel(CONTAINER, examples, variables, options) tf.initialize_all_variables().run() with self.assertRaisesOpError( 'Only labels of 0.0 or 1.0 are supported right now.'): - SdcaModel(CONTAINER, examples, variables, options).minimize().run() + lr.minimize().run() def testNoWeightedExamples(self): # Setup test data with 1 positive, and 1 negative example. @@ -254,8 +268,9 @@ class SdcaOptimizerTest(TensorFlowTestCase): options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, loss_type='logistic_loss') - tf.initialize_all_variables().run() + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval()) with self.assertRaisesOpError( 'No weighted examples in 2 training examples'): @@ -281,8 +296,9 @@ class SdcaOptimizerTest(TensorFlowTestCase): options = dict(symmetric_l2_regularization=0.5, symmetric_l1_regularization=0, loss_type='logistic_loss') - tf.initialize_all_variables().run() + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval()) with self.assertRaisesOpError('Detected 1 duplicates in example_ids'): lr.minimize().run() @@ -310,19 +326,25 @@ class SdcaOptimizerTest(TensorFlowTestCase): variables = make_variable_dict(3, 1) options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, - loss_type='logistic_loss', - prior=-1.09861) - tf.initialize_all_variables().run() + loss_type='logistic_loss') + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() unregularized_loss = lr.unregularized_loss(examples) loss = lr.regularized_loss(examples) predictions = lr.predictions(examples) - lr.minimize().run() - 
self.assertAllClose(0.331710, unregularized_loss.eval(), - rtol=3e-2, atol=3e-2) - self.assertAllClose(0.591295, loss.eval(), rtol=3e-2, atol=3e-2) + for _ in xrange(5): + lr.minimize().run() + self.assertAllClose(0.226487 + 0.102902, + unregularized_loss.eval(), + rtol=0.08) + self.assertAllClose(0.328394 + 0.131364, loss.eval(), atol=0.01) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 0, 0, 1], predicted_labels.eval()) + self.assertAllClose(0.01, + lr.approximate_duality_gap().eval(), + rtol=1e-2, + atol=1e-2) def testImbalancedWithExampleWeights(self): # Setup test data with 1 positive, and 1 negative example. @@ -341,17 +363,22 @@ class SdcaOptimizerTest(TensorFlowTestCase): options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, loss_type='logistic_loss') - tf.initialize_all_variables().run() + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() unregularized_loss = lr.unregularized_loss(examples) loss = lr.regularized_loss(examples) predictions = lr.predictions(examples) - lr.minimize().run() - self.assertAllClose(0.266189, unregularized_loss.eval(), - rtol=3e-2, atol=3e-2) - self.assertAllClose(0.571912, loss.eval(), rtol=3e-2, atol=3e-2) + for _ in xrange(5): + lr.minimize().run() + self.assertAllClose(0.284860, unregularized_loss.eval(), rtol=0.08) + self.assertAllClose(0.408044, loss.eval(), atol=0.012) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 1], predicted_labels.eval()) + self.assertAllClose(0.01, + lr.approximate_duality_gap().eval(), + rtol=1e-2, + atol=1e-2) def testInstancesOfOneClassOnly(self): # Setup test data with 1 positive (ignored), and 1 negative example. 
@@ -367,24 +394,25 @@ class SdcaOptimizerTest(TensorFlowTestCase): with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) variables = make_variable_dict(1, 1) - options = dict(symmetric_l2_regularization=0.25, + options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, loss_type='logistic_loss') - tf.initialize_all_variables().run() + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() unregularized_loss = lr.unregularized_loss(examples) loss = lr.regularized_loss(examples) - prediction = lr.predictions(examples) - lr.minimize().run() - self.assertAllClose(0.395226, - unregularized_loss.eval(), - rtol=3e-2, - atol=3e-2) - self.assertAllClose(0.460781, loss.eval(), rtol=3e-2, atol=3e-2) - predicted_labels = tf.cast( - tf.greater_equal(prediction, - tf.ones_like(prediction) * 0.5), tf.float32) + predictions = lr.predictions(examples) + for _ in xrange(5): + lr.minimize().run() + self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.12) + self.assertAllClose(0.525457, loss.eval(), atol=0.01) + predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 0], predicted_labels.eval()) + self.assertAllClose(0.01, + lr.approximate_duality_gap().eval(), + rtol=1e-2, + atol=1e-2) def testSimpleLinear(self): # Setup test data @@ -402,21 +430,26 @@ class SdcaOptimizerTest(TensorFlowTestCase): variables = make_variable_dict(1, 1) options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, - loss_type='squared_loss', - prior=0.0) - tf.initialize_all_variables().run() + loss_type='squared_loss') + lr = SdcaModel(CONTAINER, examples, variables, options) - prediction = lr.predictions(examples) + tf.initialize_all_variables().run() + predictions = lr.predictions(examples) - lr.minimize().run() + for _ in xrange(20): + lr.minimize().run() # Predictions should be 2/3 of label due to minimizing regularized loss: # 
(label - 2 * weight)^2 / 2 + L2 * 2 * weight^2 self.assertAllClose([-20.0 / 3.0, 28.0 / 3.0], - prediction.eval(), + predictions.eval(), rtol=0.005) + self.assertAllClose(0.01, + lr.approximate_duality_gap().eval(), + rtol=1e-2, + atol=1e-2) - def testLinearRegularization(self): + def testLinearL2Regularization(self): # Setup test data example_protos = [ # 2 identical examples @@ -440,13 +473,14 @@ class SdcaOptimizerTest(TensorFlowTestCase): variables = make_variable_dict(1, 1) options = dict(symmetric_l2_regularization=16, symmetric_l1_regularization=0, - loss_type='squared_loss', - prior=0.0) - tf.initialize_all_variables().run() + loss_type='squared_loss') + lr = SdcaModel(CONTAINER, examples, variables, options) - prediction = lr.predictions(examples) + tf.initialize_all_variables().run() + predictions = lr.predictions(examples) - lr.minimize().run() + for _ in xrange(5): + lr.minimize().run() # Predictions should be 1/5 of label due to minimizing regularized loss: # (label - 2 * weight)^2 + L2 * 16 * weight^2 @@ -454,9 +488,42 @@ class SdcaOptimizerTest(TensorFlowTestCase): optimal2 = 14.0 / 5.0 self.assertAllClose( [optimal1, optimal1, optimal2, optimal2], - prediction.eval(), + predictions.eval(), rtol=0.01) + def testLinearL1Regularization(self): + # Setup test data + example_protos = [ + make_example_proto( + {'age': [0], + 'gender': [0]}, -10.0), + make_example_proto( + {'age': [1], + 'gender': [1]}, 14.0), + ] + example_weights = [1.0, 1.0] + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1) + options = dict(symmetric_l2_regularization=1.0, + symmetric_l1_regularization=4.0, + loss_type='squared_loss') + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() + prediction = lr.predictions(examples) + loss = lr.regularized_loss(examples) + + for _ in xrange(5): + lr.minimize().run() + + # Predictions should be -4.0, 48/5 
due to minimizing regularized loss: + # (label - 2 * weight)^2 / 2 + L2 * 2 * weight^2 + L1 * 4 * weight + self.assertAllClose([-4.0, 20.0 / 3.0], prediction.eval(), rtol=0.08) + + # Loss should be the sum of the regularized loss value from above per + # example after plugging in the optimal weights. + self.assertAllClose(308.0 / 6.0, loss.eval(), atol=0.01) + def testLinearFeatureValues(self): # Setup test data example_protos = [ @@ -474,18 +541,19 @@ class SdcaOptimizerTest(TensorFlowTestCase): variables = make_variable_dict(1, 1) options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, - loss_type='squared_loss', - prior=0.0) - tf.initialize_all_variables().run() + loss_type='squared_loss') + lr = SdcaModel(CONTAINER, examples, variables, options) - prediction = lr.predictions(examples) + tf.initialize_all_variables().run() + predictions = lr.predictions(examples) - lr.minimize().run() + for _ in xrange(20): + lr.minimize().run() # Predictions should be 8/9 of label due to minimizing regularized loss: # (label - 2 * 2 * weight)^2 / 2 + L2 * 2 * weight^2 self.assertAllClose([-10.0 * 8 / 9, 14.0 * 8 / 9], - prediction.eval(), + predictions.eval(), rtol=0.07) def testLinearDenseFeatures(self): @@ -497,25 +565,22 @@ class SdcaOptimizerTest(TensorFlowTestCase): variables = make_dense_variable_dict(2, 2) options = dict(symmetric_l2_regularization=1, symmetric_l1_regularization=0, - loss_type='squared_loss', - prior=0.0) - tf.initialize_all_variables().run() + loss_type='squared_loss') lr = SdcaModel(CONTAINER, examples, variables, options) - prediction = lr.predictions(examples) + tf.initialize_all_variables().run() + predictions = lr.predictions(examples) - lr.minimize().run() + for _ in xrange(20): + lr.minimize().run() # Predictions should be 4/5 of label due to minimizing regularized loss: # (label - 2 * weight)^2 / 2 + L2 * weight^2 self.assertAllClose([-10.0 * 4 / 5, 14.0 * 4 / 5], - prediction.eval(), + predictions.eval(), rtol=0.01) loss = 
lr.regularized_loss(examples) - self.assertAllClose( - (4.0 + 7.84 + 16.0 + 31.36) / 2, - loss.eval(), - rtol=0.01) + self.assertAllClose(148.0 / 10.0, loss.eval(), atol=0.01) def testSimpleHinge(self): # Setup test data @@ -533,10 +598,9 @@ class SdcaOptimizerTest(TensorFlowTestCase): variables = make_variable_dict(1, 1) options = dict(symmetric_l2_regularization=1.0, symmetric_l1_regularization=0, - loss_type='hinge_loss', - prior=0.0) - tf.initialize_all_variables().run() + loss_type='hinge_loss') model = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() # Before minimization, the weights default to zero. There is no loss due # to regularization, only unregularized loss which is 0.5 * (1+1) = 1.0. @@ -551,13 +615,15 @@ class SdcaOptimizerTest(TensorFlowTestCase): # are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender (say w3 # and w4). Solving the system w1 + w3 = 1.0, w2 + w4 = -1.0 and minimizing # wrt to \|\vec{w}\|_2, gives w1=w3=1/2 and w2=w4=-1/2. This gives 0.0 - # unregularized loss and 0.5 L2 loss. - model.minimize().run() + # unregularized loss and 0.25 L2 loss. 
+ for _ in xrange(5): + model.minimize().run() + binary_predictions = get_binary_predictions_for_hinge(predictions) self.assertAllEqual([-1.0, 1.0], predictions.eval()) self.assertAllEqual([0.0, 1.0], binary_predictions.eval()) self.assertAllClose(0.0, unregularized_loss.eval()) - self.assertAllClose(0.5, regularized_loss.eval(), atol=0.05) + self.assertAllClose(0.25, regularized_loss.eval(), atol=0.05) def testHingeDenseFeaturesPerfectlySeparable(self): with self._single_threaded_test_session(): @@ -569,22 +635,25 @@ class SdcaOptimizerTest(TensorFlowTestCase): options = dict(symmetric_l2_regularization=1.0, symmetric_l1_regularization=0, loss_type='hinge_loss') - tf.initialize_all_variables().run() model = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() predictions = model.predictions(examples) binary_predictions = get_binary_predictions_for_hinge(predictions) - model.minimize().run() + + for _ in xrange(5): + model.minimize().run() + self.assertAllClose([1.0, -1.0], predictions.eval(), atol=0.05) self.assertAllClose([1.0, 0.0], binary_predictions.eval()) # (1.0, 1.0) and (1.0, -1.0) are perfectly separable by x-axis (that is, # the SVM's functional margin >=1), so the unregularized loss is ~0.0. # There is only loss due to l2-regularization. For these datapoints, it - # turns out that w_1~=0.0 and w_2~=1.0 which means that l2 loss is ~0.5. + # turns out that w_1~=0.0 and w_2~=1.0 which means that l2 loss is ~0.25. 
unregularized_loss = model.unregularized_loss(examples) regularized_loss = model.regularized_loss(examples) self.assertAllClose(0.0, unregularized_loss.eval(), atol=0.02) - self.assertAllClose(0.5, regularized_loss.eval(), atol=0.02) + self.assertAllClose(0.25, regularized_loss.eval(), atol=0.02) def testHingeDenseFeaturesSeparableWithinMargins(self): with self._single_threaded_test_session(): @@ -596,22 +665,24 @@ class SdcaOptimizerTest(TensorFlowTestCase): options = dict(symmetric_l2_regularization=1.0, symmetric_l1_regularization=0, loss_type='hinge_loss') - tf.initialize_all_variables().run() model = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() predictions = model.predictions(examples) binary_predictions = get_binary_predictions_for_hinge(predictions) - model.minimize().run() + + for _ in xrange(5): + model.minimize().run() # (1.0, 0.5) and (1.0, -0.5) are separable by x-axis but the datapoints # are within the margins so there is unregularized loss (1/2 per example). # For these datapoints, optimal weights are w_1~=0.0 and w_2~=1.0 which - # gives an L2 loss of ~0.5. + # gives an L2 loss of ~0.25. 
self.assertAllClose([0.5, -0.5], predictions.eval(), rtol=0.05) self.assertAllClose([1.0, 0.0], binary_predictions.eval()) unregularized_loss = model.unregularized_loss(examples) regularized_loss = model.regularized_loss(examples) self.assertAllClose(0.5, unregularized_loss.eval(), atol=0.02) - self.assertAllClose(1.0, regularized_loss.eval(), atol=0.02) + self.assertAllClose(0.75, regularized_loss.eval(), atol=0.02) def testHingeDenseFeaturesWeightedExamples(self): with self._single_threaded_test_session(): @@ -623,24 +694,26 @@ class SdcaOptimizerTest(TensorFlowTestCase): options = dict(symmetric_l2_regularization=1.0, symmetric_l1_regularization=0, loss_type='hinge_loss') - tf.initialize_all_variables().run() model = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() predictions = model.predictions(examples) binary_predictions = get_binary_predictions_for_hinge(predictions) - model.minimize().run() + for _ in xrange(5): + model.minimize().run() # Point (1.0, 0.5) has higher weight than (1.0, -0.5) so the model will # try to increase the margin from (1.0, 0.5). Due to regularization, # (1.0, -0.5) will be within the margin. For these points and example # weights, the optimal weights are w_1~=0.4 and w_2~=1.2 which give an L2 - # loss of 0.25 * 1.6 = 0.4. The binary predictions will be correct, but - # the boundary will be much closer to the 2nd point than the first one. + # loss of 0.5 * 0.25 * 0.25 * 1.6 = 0.2. The binary predictions will be + # correct, but the boundary will be much closer to the 2nd point than the + # first one. 
self.assertAllClose([1.0, -0.2], predictions.eval(), atol=0.05) self.assertAllClose([1.0, 0.0], binary_predictions.eval(), atol=0.05) unregularized_loss = model.unregularized_loss(examples) regularized_loss = model.regularized_loss(examples) self.assertAllClose(0.2, unregularized_loss.eval(), atol=0.02) - self.assertAllClose(0.6, regularized_loss.eval(), atol=0.02) + self.assertAllClose(0.4, regularized_loss.eval(), atol=0.02) if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index e986d40338..957a734b07 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -20,11 +20,16 @@ from __future__ import print_function import os.path import uuid +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework.load_library import load_op_library from tensorflow.python.framework.ops import convert_to_tensor from tensorflow.python.framework.ops import name_scope from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables as var_ops from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits from tensorflow.python.platform import resource_loader @@ -99,11 +104,12 @@ class SdcaModel(object): the model, by resetting its (possibly shared) container. ```python - # Execute opt_op once to perform training, which continues until - convergence. - The op makes use of duality gap as a certificate for termination. Duality - gap is set to 0.01 as default. - opt_op.run() + # Execute opt_op and train for num_steps. 
+ for _ in xrange(num_steps): + opt_op.run() + + # You can also check for convergence by calling + # lr.approximate_duality_gap() ``` """ @@ -125,8 +131,7 @@ class SdcaModel(object): self._assertList(['sparse_features', 'dense_features'], examples) self._assertSpecified( - ['sparse_features_weights', 'dense_features_weights', - 'primal_loss'], variables) + ['sparse_features_weights', 'dense_features_weights'], variables) self._assertList( ['sparse_features_weights', 'dense_features_weights'], variables) @@ -138,9 +143,26 @@ class SdcaModel(object): self._examples = examples self._variables = variables self._options = options - self._primal_loss = convert_to_tensor(self._variables['primal_loss'], - as_ref=True) self._solver_uuid = uuid.uuid4().hex + self._create_slots(variables) + + # TODO(rohananil): Use optimizer interface to make use of slot creation + # logic + def _create_slots(self, variables): + self._slots = {} + # TODO(rohananil): Rename the slot keys to "unshrinked" weights. + self._slots['sparse_features_weights'] = [] + self._slots['dense_features_weights'] = [] + self._assign_ops = [] + # Make an internal variable which has the updates before applying L1 + # regularization. 
+ for var_type in ['sparse_features_weights', 'dense_features_weights']: + for var in variables[var_type]: + if var is not None: + self._slots[var_type].append(var_ops.Variable(array_ops.zeros_like( + var.initialized_value(), dtypes.float32))) + self._assign_ops.append(state_ops.assign(var, self._slots[var_type][ + -1])) def _assertSpecified(self, items, check_in): for x in items: @@ -160,7 +182,7 @@ class SdcaModel(object): dense_weights = self._convert_n_to_tensor(self._variables[ 'dense_features_weights']) l1 = self._options['symmetric_l1_regularization'] - loss = 0 + loss = 0.0 for w in sparse_weights: loss += l1 * math_ops.reduce_sum(abs(w)) for w in dense_weights: @@ -175,12 +197,13 @@ class SdcaModel(object): dense_weights = self._convert_n_to_tensor(self._variables[ 'dense_features_weights']) l2 = self._options['symmetric_l2_regularization'] - loss = 0 + loss = 0.0 for w in sparse_weights: loss += l2 * math_ops.reduce_sum(math_ops.square(w)) for w in dense_weights: loss += l2 * math_ops.reduce_sum(math_ops.square(w)) - return loss + # SDCA L2 regularization cost is 1/2 * l2 * sum(weights^2) + return loss / 2.0 def _convert_n_to_tensor(self, input_list, as_ref=False): """Converts input list to a set of tensors.""" @@ -247,23 +270,52 @@ class SdcaModel(object): sparse_features_indices.append(convert_to_tensor(sf.indices)) sparse_features_weights.append(convert_to_tensor(sf.values)) - return _sdca_ops.sdca_solver( + step_op = _sdca_ops.sdca_solver( sparse_features_indices, sparse_features_weights, self._convert_n_to_tensor(self._examples['dense_features']), convert_to_tensor(self._examples['example_weights']), convert_to_tensor(self._examples['example_labels']), convert_to_tensor(self._examples['example_ids']), - self._convert_n_to_tensor(self._variables['sparse_features_weights'], + self._convert_n_to_tensor(self._slots['sparse_features_weights'], as_ref=True), - self._convert_n_to_tensor(self._variables['dense_features_weights'], + 
self._convert_n_to_tensor(self._slots['dense_features_weights'], as_ref=True), - self._primal_loss, l1=self._options['symmetric_l1_regularization'], l2=self._options['symmetric_l2_regularization'], + num_inner_iterations=2, loss_type=self._options['loss_type'], container=self._container, solver_uuid=self._solver_uuid) + with ops.control_dependencies([step_op]): + assign_ops = control_flow_ops.group(*self._assign_ops) + with ops.control_dependencies([assign_ops]): + return _sdca_ops.sdca_shrink_l1( + self._convert_n_to_tensor( + self._variables['sparse_features_weights'], + as_ref=True), + self._convert_n_to_tensor( + self._variables['dense_features_weights'], + as_ref=True), + l1=self._options['symmetric_l1_regularization'], + l2=self._options['symmetric_l2_regularization']) + + def approximate_duality_gap(self): + """Add operations to compute the approximate duality gap. + + Returns: + An Operation that computes the approximate duality gap over all + examples. + """ + return _sdca_ops.compute_duality_gap( + self._convert_n_to_tensor(self._slots['sparse_features_weights'], + as_ref=True), + self._convert_n_to_tensor(self._slots['dense_features_weights'], + as_ref=True), + l1=self._options['symmetric_l1_regularization'], + l2=self._options['symmetric_l2_regularization'], + container=self._container, + solver_uuid=self._solver_uuid) def unregularized_loss(self, examples): """Add operations to compute the loss (without the regularization loss). @@ -310,8 +362,9 @@ class SdcaModel(object): err = math_ops.sub(labels, predictions) weighted_squared_err = math_ops.mul(math_ops.square(err), weights) + # SDCA squared loss function is sum(err^2) / (2*sum(weights)) return (math_ops.reduce_sum(weighted_squared_err) / - math_ops.reduce_sum(weights)) + (2.0 * math_ops.reduce_sum(weights))) def regularized_loss(self, examples): """Add operations to compute the loss with regularization loss included. 
@@ -321,7 +374,7 @@ class SdcaModel(object): Returns: An Operation that computes mean (regularized) loss for given set of - examples. + examples. Raises: ValueError: if examples are not well defined. """ diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index dc9d398a1b..5ee2337647 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -36,10 +36,6 @@ # # filegroup ":android_proto_srcs" - Protos # filegroup ":android_srcs" - Core sources -# filegroup ":android_core_ops" - Essential kernels -# filegroup ":android_extended_ops" - Optional kernels -# filegroup ":android_extended_ops_group1" - Optional kernels, first batch -# filegroup ":android_extended_ops_group2" - Optional kernels, second batch # cc_library ":android_tensorflow_lib" - Native library # portable_proto_library ":android_proto_lib" (Google-internal) @@ -117,7 +113,6 @@ tf_proto_library_cc( srcs = ["protobuf/master.proto"], cc_api_version = 2, cc_libs = [":protos_all_cc"], - py_api_version = 2, visibility = [ "//tensorflow:internal", ], @@ -131,7 +126,6 @@ tf_proto_library_cc( cc_grpc_version = 1, cc_libs = [":master_proto_cc"], cc_stubby_versions = ["2"], - py_api_version = 2, visibility = [ "//tensorflow:internal", ], @@ -611,57 +605,6 @@ filegroup( visibility = ["//visibility:public"], ) -# Core kernels we want on Android. Only a subset of kernels to keep -# base library small. -filegroup( - name = "android_core_ops", - srcs = [ - "//tensorflow/core/kernels:android_core_ops", - ], - visibility = ["//visibility:public"], -) - -# Other kernels we may want on Android. -# -# The kernels can be consumed as a whole or in two groups for -# supporting separate compilation. Note that the split into groups -# is entirely for improving compilation time, and not for -# organizational reasons; you should not depend on any -# of those groups independently. 
-filegroup( - name = "android_extended_ops", - srcs = [ - ":android_extended_ops_group1", - ":android_extended_ops_group2", - ], - visibility = ["//visibility:public"], -) - -filegroup( - name = "android_extended_ops_headers", - srcs = [ - "//tensorflow/core/kernels:android_extended_ops_headers", - ], -) - -filegroup( - name = "android_extended_ops_group1", - srcs = [ - ":android_extended_ops_headers", - "//tensorflow/core/kernels:android_extended_ops_group1", - ], - visibility = ["//visibility:public"], -) - -filegroup( - name = "android_extended_ops_group2", - srcs = [ - ":android_extended_ops_headers", - "//tensorflow/core/kernels:android_extended_ops_group2", - ], - visibility = ["//visibility:public"], -) - # Config setting for determining if we are building for Android. config_setting( name = "android", @@ -718,8 +661,8 @@ cc_library( cc_library( name = "android_tensorflow_lib", srcs = [ - "//tensorflow/core:android_core_ops", - "//tensorflow/core:android_extended_ops", + "//tensorflow/core/kernels:android_core_ops", + "//tensorflow/core/kernels:android_extended_ops", ], copts = select({ ":android": ANDROID_TF_COPTS, @@ -738,6 +681,7 @@ cc_library( ], ) +# ----------------------------------------------------------------------------- # Libraries for GPU facilities that are useful for writing kernels. cc_library( name = "gpu_lib", @@ -963,10 +907,6 @@ tf_cuda_library( ], ), copts = tf_copts(), - cuda_deps = [ - ":core_gpu_internal", - ":stream_executor", - ], deps = [ ":framework", ":framework_internal", @@ -982,39 +922,6 @@ tf_cuda_library( alwayslink = 1, ) -# This target should not link in any GPU runtime (CUDA) dependencies, -# only libraries for interfacing with GPUs that can be safely linked -# into CPU binaries. 
-cc_library( - name = "core_gpu_internal", - srcs = [ - "common_runtime/gpu/gpu_allocator_retry.cc", - "common_runtime/gpu/gpu_bfc_allocator.cc", - "common_runtime/gpu/gpu_debug_allocator.cc", - "common_runtime/gpu/gpu_init.cc", - "common_runtime/gpu/pool_allocator.cc", - "common_runtime/gpu/process_state.cc", - ], - hdrs = [ - "common_runtime/gpu/gpu_allocator_retry.h", - "common_runtime/gpu/gpu_bfc_allocator.h", - "common_runtime/gpu/gpu_debug_allocator.h", - "common_runtime/gpu/gpu_init.h", - "common_runtime/gpu/pool_allocator.h", - "common_runtime/gpu/process_state.h", - "common_runtime/gpu/visitable_allocator.h", - ], - copts = tf_copts(), - deps = [ - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":stream_executor", - ], -) - cc_library( name = "cuda", deps = [ @@ -1044,16 +951,29 @@ tf_cuda_library( tf_cuda_library( name = "gpu_runtime", srcs = [ + "common_runtime/gpu/gpu_allocator_retry.cc", + "common_runtime/gpu/gpu_bfc_allocator.cc", + "common_runtime/gpu/gpu_debug_allocator.cc", "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", + "common_runtime/gpu/gpu_init.cc", "common_runtime/gpu/gpu_stream_util.cc", "common_runtime/gpu/gpu_util.cc", "common_runtime/gpu/gpu_util_platform_specific.cc", + "common_runtime/gpu/pool_allocator.cc", + "common_runtime/gpu/process_state.cc", ], hdrs = [ + "common_runtime/gpu/gpu_allocator_retry.h", + "common_runtime/gpu/gpu_bfc_allocator.h", + "common_runtime/gpu/gpu_debug_allocator.h", "common_runtime/gpu/gpu_device.h", + "common_runtime/gpu/gpu_init.h", "common_runtime/gpu/gpu_stream_util.h", "common_runtime/gpu/gpu_util.h", + "common_runtime/gpu/pool_allocator.h", + "common_runtime/gpu/process_state.h", + "common_runtime/gpu/visitable_allocator.h", ], copts = tf_copts(), cuda_deps = [ @@ -1063,7 +983,6 @@ tf_cuda_library( deps = [ ":core_cpu", ":core_cpu_internal", - ":core_gpu_internal", ":framework", ":framework_internal", ":gpu_lib", @@ -1247,7 
+1166,6 @@ tf_cc_tests( ":all_kernels", ":core_cpu", ":core_cpu_internal", - ":core_gpu_internal", ":direct_session", ":framework", ":framework_internal", diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc index d37a55784d..d0726f235c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_device.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/threadpool_device.h" namespace tensorflow { @@ -61,6 +62,49 @@ class GPUDeviceFactory : public BaseGPUDeviceFactory { REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory); +//------------------------------------------------------------------------------ +// A CPUDevice that optimizes for interaction with GPUs in the +// process. +// ----------------------------------------------------------------------------- +class GPUCompatibleCPUDevice : public ThreadPoolDevice { + public: + GPUCompatibleCPUDevice(const SessionOptions& options, const string& name, + Bytes memory_limit, BusAdjacency bus_adjacency, + Allocator* allocator) + : ThreadPoolDevice(options, name, memory_limit, bus_adjacency, + allocator) {} + ~GPUCompatibleCPUDevice() override {} + + Allocator* GetAllocator(AllocatorAttributes attr) override { + ProcessState* ps = ProcessState::singleton(); + if (attr.gpu_compatible()) { + return ps->GetCUDAHostAllocator(0); + } else { + // Call the parent's implementation. + return ThreadPoolDevice::GetAllocator(attr); + } + } +}; + +// The associated factory. 
+class GPUCompatibleCPUDeviceFactory : public DeviceFactory { + public: + void CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector<Device*>* devices) override { + int n = 1; + auto iter = options.config.device_count().find("CPU"); + if (iter != options.config.device_count().end()) { + n = iter->second; + } + for (int i = 0; i < n; i++) { + string name = strings::StrCat(name_prefix, "/cpu:", i); + devices->push_back(new GPUCompatibleCPUDevice( + options, name, Bytes(256 << 20), BUS_ANY, cpu_allocator())); + } + } +}; +REGISTER_LOCAL_DEVICE_FACTORY("CPU", GPUCompatibleCPUDeviceFactory, 50); + } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 98f42a7e45..6477e9a336 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -26,10 +26,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/process_state.h" -#endif // GOOGLE_CUDA - namespace tensorflow { ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, @@ -56,12 +52,6 @@ void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { } Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) { -#if GOOGLE_CUDA - ProcessState* ps = ProcessState::singleton(); - if (attr.gpu_compatible()) { - return ps->GetCUDAHostAllocator(0); - } -#endif // GOOGLE_CUDA return allocator_; } diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 00d97a6ef9..fb672196ca 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -269,6 +269,18 @@ cc_library( ], ) +cc_library( + name = "server_lib", + srcs = ["server_lib.cc"], + hdrs = 
["server_lib.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # TODO(mrry): Move executor_test.cc to ../common_runtime when once it no longer depends # on grpc_testlib. tf_cc_tests( diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index d9d016e67e..df86046c45 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -211,7 +211,7 @@ cc_library( srcs = [ "grpc_server_lib.cc", ], - hdrs = ["grpc_server_lib.h"], + linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel deps = [ "@grpc//:grpc++_unsecure", ":async_service_interface", @@ -230,8 +230,10 @@ cc_library( "//tensorflow/core/distributed_runtime:master_env", "//tensorflow/core/distributed_runtime:master_session", "//tensorflow/core/distributed_runtime:process_util", + "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", ], + alwayslink = 1, ) cc_binary( @@ -247,6 +249,7 @@ cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/distributed_runtime:server_lib", ], ) @@ -276,6 +279,7 @@ cc_binary( "//tensorflow/core:core_cpu", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime:server_lib", ], ) @@ -344,5 +348,6 @@ tf_cc_tests( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/distributed_runtime:process_util", + "//tensorflow/core/distributed_runtime:server_lib", ], ) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 5cefc7605f..629441cee4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ 
-13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" - #include <memory> #include "grpc++/grpc++.h" @@ -33,6 +31,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -41,14 +40,14 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" namespace tensorflow { - namespace { -class TensorFlowServer : public ServerInterface { + +class GrpcServer : public ServerInterface { public: - TensorFlowServer(const ServerDef& server_def, Env* env) + GrpcServer(const ServerDef& server_def, Env* env) : server_def_(server_def), env_(env), state_(NEW) {} - ~TensorFlowServer() { + ~GrpcServer() { Stop(); Join(); @@ -59,8 +58,14 @@ class TensorFlowServer : public ServerInterface { // to destroy them. delete master_env_.worker_cache; // Shared with worker_env.worker_cache. - delete worker_env_.device_mgr; + // We must delete graph_mgr before device_mgr, due to shared + // ownership of OpKernels in the executors. (The graph_mgr will + // free all stateless OpKernels, and pass over borrowed stateful + // OpKernels, which are also held in their respective devices' + // OpSegments.) 
delete worker_env_.graph_mgr; + delete worker_env_.device_mgr; + delete worker_env_.rendezvous_mgr; // Do not delete (as these are not owned by the server): @@ -91,6 +96,56 @@ class TensorFlowServer : public ServerInterface { return errors::Internal("Could not parse worker name."); } + // Look up the port that has been requested for this task in `server_def_`. + requested_port_ = -1; + for (const auto& job : server_def_.cluster().job()) { + if (job.name() == server_def_.job_name()) { + auto iter = job.tasks().find(server_def_.task_index()); + if (iter == job.tasks().end()) { + return errors::InvalidArgument("Task ", server_def_.task_index(), + " was not defined in job \"", + server_def_.job_name(), "\""); + } else if (!str_util::NumericParse32( + str_util::Split(iter->second, ':')[1], + &requested_port_)) { + return errors::Internal( + "Could not parse port for local server from \"", iter->second, + "\""); + } else { + break; + } + } + } + if (requested_port_ == -1) { + return errors::Internal("Job \"", server_def_.job_name(), + "\" was not defined in cluster"); + } + + // N.B. The order of initialization here is intricate, because we + // wish to allow `requested_port_ == 0` (for choosing any port, + // mostly for testing). Therefore, the construction of the channel + // and worker caches depends on `bound_port_`, which is not set + // until we call `builder.BuildAndStart()`. We must create the + // service objects before calling `builder.BuildAndStart()`, but + // `master_env_` and `worker_env_` are only partially + // configured. However, this is not dangerous, because we do not + // start serving requests until `this->Start()` is called, which + // happens after this method returns. + // + // TODO(mrry): Provide a general mechanism for dynamically setting + // the identities of tasks in the worker pool after the service is + // running. 
+ ::grpc::ServerBuilder builder; + builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_), + ::grpc::InsecureServerCredentials(), &bound_port_); + master_service_ = NewGrpcMasterService(&master_env_, &builder); + worker_service_ = NewGrpcWorkerService(&worker_env_, &builder); + server_ = builder.BuildAndStart(); + + if (!server_) { + return errors::Internal("Could not start gRPC server"); + } + GrpcChannelSpec channel_spec; for (const auto& job : server_def_.cluster().job()) { int max_task_id = -1; @@ -99,7 +154,12 @@ class TensorFlowServer : public ServerInterface { } std::vector<string> host_ports(max_task_id + 1); for (const auto& task : job.tasks()) { - host_ports[task.first] = task.second; + if (job.name() == server_def_.job_name() && + task.first == server_def_.task_index()) { + host_ports[task.first] = strings::StrCat("localhost:", bound_port_); + } else { + host_ports[task.first] = task.second; + } } channel_spec.AddHostPortsJob(job.name(), host_ports, host_ports.size()); } @@ -133,12 +193,6 @@ class TensorFlowServer : public ServerInterface { mutex_lock l(mu_); switch (state_) { case NEW: { - ::grpc::ServerBuilder builder; - builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_), - ::grpc::InsecureServerCredentials()); - master_service_ = NewGrpcMasterService(&master_env_, &builder); - worker_service_ = NewGrpcWorkerService(&worker_env_, &builder); - server_ = builder.BuildAndStart(); master_thread_.reset( env_->StartThread(ThreadOptions(), "TF_master_service", [this] { master_service_->HandleRPCsLoop(); })); @@ -196,7 +250,9 @@ class TensorFlowServer : public ServerInterface { } } - const string& target() const override { return target_; } + const string target() const override { + return strings::StrCat("grpc://localhost:", bound_port_); + } private: // The overall server configuration. @@ -204,8 +260,9 @@ class TensorFlowServer : public ServerInterface { Env* env_; // The port requested for this server. 
- // TODO(mrry): Support requested_port_ == 0 to bind to any available port. int requested_port_; + // The port to which this server is bound. + int bound_port_ = 0; // The `SessionOptions.target` to be used when connecting to this // server (as a master). @@ -238,15 +295,30 @@ class TensorFlowServer : public ServerInterface { std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); }; -} // namespace -Status NewServer(const ServerDef& server_def, - std::unique_ptr<ServerInterface>* out_server) { - std::unique_ptr<TensorFlowServer> ret( - new TensorFlowServer(server_def, Env::Default())); - TF_RETURN_IF_ERROR(ret->Init()); - *out_server = std::move(ret); - return Status::OK(); -} +class GrpcServerFactory : public ServerFactory { + public: + bool AcceptsOptions(const ServerDef& server_def) override { + return server_def.protocol() == "grpc"; + } + Status NewServer(const ServerDef& server_def, + std::unique_ptr<ServerInterface>* out_server) override { + std::unique_ptr<GrpcServer> ret(new GrpcServer(server_def, Env::Default())); + TF_RETURN_IF_ERROR(ret->Init()); + *out_server = std::move(ret); + return Status::OK(); + } +}; + +// Registers a `ServerFactory` for `GrpcServer` instances. +class GrpcServerRegistrar { + public: + GrpcServerRegistrar() { + ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory()); + } +}; +static GrpcServerRegistrar registrar; + +} // namespace } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h deleted file mode 100644 index a06989d88c..0000000000 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2016 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ - -#include <memory> - -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/protobuf/tensorflow_server.pb.h" - -namespace tensorflow { - -// Represents a single TensorFlow server, which exports Master and Worker -// services. -class ServerInterface { - public: - ServerInterface() {} - virtual ~ServerInterface() {} - - // Starts the server running asynchronously. Returns OK on success, otherwise - // returns an error. - virtual Status Start() = 0; - - // Stops the server asynchronously. Returns OK on success, otherwise returns - // an error. - // - // After calling `Stop()`, the caller may call `Join()` to block until the - // server has stopped. - virtual Status Stop() = 0; - - // Blocks until the server has stopped. Returns OK on success, otherwise - // returns an error. - virtual Status Join() = 0; - - // Returns a target string that can be used to connect to this server using - // `tensorflow::NewSession()`. - virtual const string& target() const = 0; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(ServerInterface); -}; - -// Creates a server based on the given `server_def`, and stores it in -// *out_server. Returns OK on success, otherwise returns an error. 
-Status NewServer(const ServerDef& server_def, - std::unique_ptr<ServerInterface>* out_server); - -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc index 902519769c..a56afb05a6 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_session.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,6 +25,7 @@ namespace tensorflow { // when no calls are made against the server. TEST(Server, StopAfterNoop) { ServerDef def; + def.set_protocol("grpc"); def.set_job_name("localhost"); def.set_task_index(0); JobDef* job_def = def.mutable_cluster()->add_job(); @@ -42,6 +43,7 @@ TEST(Server, StopAfterNoop) { // when a simple call is made against the server. TEST(Server, StopAfterCall) { ServerDef def; + def.set_protocol("grpc"); def.set_job_name("localhost"); def.set_task_index(0); JobDef* job_def = def.mutable_cluster()->add_job(); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc index 4b4aa0a2f9..27651b4770 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "grpc++/security/credentials.h" #include "grpc++/server_builder.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -31,10 +31,13 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" // This binary starts a TensorFlow server (master and worker). +// +// TODO(mrry): Replace with a py_binary that uses `tf.GrpcServer()`. namespace tensorflow { namespace { Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { + options->set_protocol("grpc"); string cluster_spec; int task_index = 0; const bool parse_result = ParseFlags( diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc index 700ae1f373..a563f124c4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc @@ -17,7 +17,7 @@ limitations under the License. #include "grpc++/security/credentials.h" #include "grpc++/server_builder.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -33,6 +33,7 @@ namespace tensorflow { namespace { Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { + options->set_protocol("grpc"); string job_spec; int num_cpus = 1; int num_gpus = 0; diff --git a/tensorflow/core/distributed_runtime/server_lib.cc b/tensorflow/core/distributed_runtime/server_lib.cc new file mode 100644 index 0000000000..45d4f70a3f --- /dev/null +++ b/tensorflow/core/distributed_runtime/server_lib.cc @@ -0,0 +1,73 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/server_lib.h" + +#include <unordered_map> + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +namespace { +mutex* get_server_factory_lock() { + static mutex server_factory_lock; + return &server_factory_lock; +} + +typedef std::unordered_map<string, ServerFactory*> ServerFactories; +ServerFactories* server_factories() { + static ServerFactories* factories = new ServerFactories; + return factories; +} +} // namespace + +/* static */ +void ServerFactory::Register(const string& server_type, + ServerFactory* factory) { + mutex_lock l(*get_server_factory_lock()); + if (!server_factories()->insert({server_type, factory}).second) { + LOG(ERROR) << "Two server factories are being registered under " + << server_type; + } +} + +/* static */ +Status ServerFactory::GetFactory(const ServerDef& server_def, + ServerFactory** out_factory) { + mutex_lock l(*get_server_factory_lock()); + // TODO(mrry): Improve the error reporting here. 
+ for (const auto& server_factory : *server_factories()) { + if (server_factory.second->AcceptsOptions(server_def)) { + *out_factory = server_factory.second; + return Status::OK(); + } + } + return errors::NotFound( + "No server factory registered for the given ServerDef: ", + server_def.DebugString()); +} + +// Creates a server based on the given `server_def`, and stores it in +// `*out_server`. Returns OK on success, otherwise returns an error. +Status NewServer(const ServerDef& server_def, + std::unique_ptr<ServerInterface>* out_server) { + ServerFactory* factory; + TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory)); + return factory->NewServer(server_def, out_server); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/server_lib.h b/tensorflow/core/distributed_runtime/server_lib.h new file mode 100644 index 0000000000..dea682795a --- /dev/null +++ b/tensorflow/core/distributed_runtime/server_lib.h @@ -0,0 +1,98 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ + +#include <memory> + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +namespace tensorflow { + +// This library supports a registration/factory-based mechanism for +// creating TensorFlow server objects. Each server implementation must +// have an accompanying implementation of ServerFactory, and create a +// static "registrar" object that calls `ServerFactory::Register()` +// with an instance of the factory class. See "rpc/grpc_server_lib.cc" +// for an example. + +// Represents a single TensorFlow server that exports Master and Worker +// services. +class ServerInterface { + public: + ServerInterface() {} + virtual ~ServerInterface() {} + + // Starts the server running asynchronously. Returns OK on success, otherwise + // returns an error. + virtual Status Start() = 0; + + // Stops the server asynchronously. Returns OK on success, otherwise returns + // an error. + // + // After calling `Stop()`, the caller may call `Join()` to block until the + // server has stopped. + virtual Status Stop() = 0; + + // Blocks until the server has stopped. Returns OK on success, otherwise + // returns an error. + virtual Status Join() = 0; + + // Returns a target string that can be used to connect to this server using + // `tensorflow::NewSession()`. + virtual const string target() const = 0; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ServerInterface); +}; + +class ServerFactory { + public: + // Creates a new server based on the given `server_def`, and stores + // it in `*out_server`. Returns OK on success, otherwise returns an + // error. 
+ virtual Status NewServer(const ServerDef& server_def, + std::unique_ptr<ServerInterface>* out_server) = 0; + + // Returns true if and only if this factory can create a server + // based on the given `server_def`. + virtual bool AcceptsOptions(const ServerDef& server_def) = 0; + + virtual ~ServerFactory() {} + + // For each `ServerFactory` subclass, an instance of that class must + // be registered by calling this method. + // + // The `server_type` must be unique to the server factory. + static void Register(const string& server_type, ServerFactory* factory); + + // Looks up a factory that can create a server based on the given + // `server_def`, and stores it in `*out_factory`. Returns OK on + // success, otherwise returns an error. + static Status GetFactory(const ServerDef& server_def, + ServerFactory** out_factory); +}; + +// Creates a server based on the given `server_def`, and stores it in +// `*out_server`. Returns OK on success, otherwise returns an error. +Status NewServer(const ServerDef& server_def, + std::unique_ptr<ServerInterface>* out_server); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index f3aecf0b96..09a1aa1a17 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -1234,7 +1234,7 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { } #define OP_REQUIRES(CTX, EXP, STATUS) \ - if (!(EXP)) { \ + if (!TF_PREDICT_TRUE(EXP)) { \ (CTX)->CtxFailure((STATUS)); \ return; \ } @@ -1242,14 +1242,14 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { #define OP_REQUIRES_OK(CTX, STATUS) \ do { \ ::tensorflow::Status _s(STATUS); \ - if (!_s.ok()) { \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ (CTX)->CtxFailureWithWarning(_s); \ return; \ } \ } while (0) #define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ - if 
(!(EXP)) { \ + if (!TF_PREDICT_TRUE(EXP)) { \ (CTX)->CtxFailure((STATUS)); \ (CALLBACK)(); \ return; \ @@ -1258,7 +1258,7 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { #define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ do { \ ::tensorflow::Status _s(STATUS); \ - if (!_s.ok()) { \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ (CTX)->CtxFailureWithWarning(_s); \ (CALLBACK)(); \ return; \ diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a7c7551ea6..7ca186eaec 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -28,16 +28,6 @@ cc_library( ], ) -cc_library( - name = "bounds_check", - hdrs = ["bounds_check.h"], - visibility = ["//visibility:private"], - deps = [ - "//tensorflow/core:framework", - "//third_party/eigen3", - ], -) - tf_kernel_library( name = "concat_lib", srcs = ["concat_lib_cpu.cc"], @@ -153,7 +143,6 @@ tf_proto_library( cc_api_version = 2, go_api_version = 2, java_api_version = 2, - py_api_version = 2, ) cc_library( @@ -200,6 +189,18 @@ cc_library( ], ) +# Private support libraries --------------------------------------------------- + +cc_library( + name = "bounds_check", + hdrs = ["bounds_check.h"], + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/core:framework", + "//third_party/eigen3", + ], +) + # OpKernel libraries ---------------------------------------------------------- tf_kernel_libraries( @@ -652,6 +653,7 @@ tf_kernel_libraries( "sparse_matmul_op", ], deps = [ + ":bounds_check", ":fill_functor", ":transpose_functor", "//tensorflow/core:core_cpu", @@ -734,6 +736,7 @@ tf_kernel_libraries( "xent_op", ], deps = [ + ":bounds_check", ":conv_2d", ":conv_ops", ":depthwise_conv_op", @@ -980,6 +983,8 @@ filegroup( ], ) +# Core kernels we want on Android. Only a subset of kernels to keep +# base library small. filegroup( name = "android_core_ops", srcs = [ @@ -1036,6 +1041,22 @@ filegroup( ], ) +# Other kernels we may want on Android. 
+# +# The kernels can be consumed as a whole or in two groups for +# supporting separate compilation. Note that the split into groups +# is entirely for improving compilation time, and not for +# organizational reasons; you should not depend on any +# of those groups independently. +filegroup( + name = "android_extended_ops", + srcs = [ + ":android_extended_ops_group1", + ":android_extended_ops_group2", + ], + visibility = ["//visibility:public"], +) + filegroup( name = "android_extended_ops_headers", srcs = [ @@ -1090,6 +1111,7 @@ filegroup( "cwise_op_sub.cc", "cwise_op_tanh.cc", "dynamic_partition_op.cc", + ":android_extended_ops_headers", ], ) @@ -1122,6 +1144,7 @@ filegroup( "transpose_op.cc", "where_op.cc", "xent_op.cc", + ":android_extended_ops_headers", ], ) diff --git a/tensorflow/core/kernels/bounds_check.h b/tensorflow/core/kernels/bounds_check.h index 665cbdaff9..9bfbde9bc7 100644 --- a/tensorflow/core/kernels/bounds_check.h +++ b/tensorflow/core/kernels/bounds_check.h @@ -33,6 +33,19 @@ EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) { static_cast<UIndex>(limit)); } +// Upcasting specializations when the index and bounds do not match; +// always move to the larger type. + +EIGEN_ALWAYS_INLINE bool FastBoundsCheck(int64 index, int32 limit) { + return TF_PREDICT_TRUE(static_cast<uint64>(index) < + static_cast<uint64>(limit)); +} + +EIGEN_ALWAYS_INLINE bool FastBoundsCheck(int32 index, int64 limit) { + return TF_PREDICT_TRUE(static_cast<uint64>(index) < + static_cast<uint64>(limit)); +} + namespace internal { // Ensure that the compiler cannot elide a copy into a local, for // bounds checking on source tensors that might be updated asynchronously. 
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc index c48dd309c7..18ee40e623 100644 --- a/tensorflow/core/kernels/decode_csv_op.cc +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -214,9 +214,10 @@ class DecodeCSVOp : public OpKernel { } OP_REQUIRES( - ctx, input[current_idx] == '"' && - (static_cast<size_t>(current_idx) == input.size() - 1 || - input[current_idx + 1] == delim_), + ctx, (static_cast<size_t>(current_idx) < input.size() && + input[current_idx] == '"' && + (static_cast<size_t>(current_idx) == input.size() - 1 || + input[current_idx + 1] == delim_)), errors::InvalidArgument("Quoted field has to end with quote " "followed by delim or end")); diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc index 0172031e43..47f334ba76 100644 --- a/tensorflow/core/kernels/in_topk_op.cc +++ b/tensorflow/core/kernels/in_topk_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { @@ -55,7 +56,10 @@ class InTopK : public OpKernel { const auto size = targets.size(); const auto num_classes = predictions.dimension(1); for (int b = 0; b < size; b++) { - T target_prediction = predictions(b, targets(b)); + auto target = internal::SubtleMustCopy(targets(b)); + OP_REQUIRES(context, FastBoundsCheck(target, num_classes), + errors::InvalidArgument("targets[", b, "] is out of range")); + T target_prediction = predictions(b, target); bool cannot_say = !std::isfinite(target_prediction); int more_probable_classes = 0; if (!cannot_say) { diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index 8b672960d3..a4e75a7796 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ 
b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/util.h" @@ -78,7 +79,9 @@ class SegmentReductionOp : public OpKernel { // Note that the current implementation assumes that segment_vec values are // sorted. const Index output_rows = - num_indices > 0 ? segment_vec(num_indices - 1) + 1 : 0; + num_indices > 0 + ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1 + : 0; TensorShape output_shape = input.shape(); output_shape.set_dim(0, output_rows); @@ -118,7 +121,14 @@ class SegmentReductionOp : public OpKernel { typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Unaligned> OutT; - T* out_slice_ptr = &output_flat(segment_vec(start), 0); + + Index out_index = internal::SubtleMustCopy(segment_vec(start)); + OP_REQUIRES( + context, FastBoundsCheck(out_index, output_rows), + errors::InvalidArgument( + "Segment id ", out_index, " out of range [0, ", output_rows, + "), probably because 'segment_ids' input is not sorted.")); + T* out_slice_ptr = &output_flat(out_index, 0); OutT out_slice(out_slice_ptr, out_slice_shape); // We don't use out_slice.device(context->eigen_device<Device>) // because these pieces of work are likely to be very small and @@ -208,7 +218,6 @@ class UnsortedSegmentSumOp : public OpKernel { context, IsLegacyScalar(num_segments.shape()), errors::InvalidArgument("num_segments should be a scalar, not shape ", num_segments.shape().DebugString())); - OP_REQUIRES( context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()), @@ -218,15 +227,11 @@ class UnsortedSegmentSumOp : public OpKernel { const auto segment_flat = segment_ids.flat<Index>(); const int32 N = 
segment_flat.dimension(0); - const int32 output_rows = num_segments.scalar<int32>()(); - - for (int i = 0; i < N; i++) { - int j = segment_flat(i); - OP_REQUIRES(context, 0 <= j && j < output_rows, - errors::InvalidArgument( - "segment_ids", SliceDebugString(segment_ids.shape(), i), - " = ", j, " is out of range [0, ", output_rows, ")")); - } + const Index output_rows = + internal::SubtleMustCopy(num_segments.scalar<int32>()()); + OP_REQUIRES(context, output_rows >= 0, + errors::InvalidArgument("Input num_segments == ", output_rows, + " must not be negative.")); TensorShape output_shape; output_shape.AddDim(output_rows); @@ -242,8 +247,12 @@ class UnsortedSegmentSumOp : public OpKernel { if (data.NumElements() > 0) { auto data_flat = data.shaped<T, 2>({N, data.NumElements() / N}); for (int i = 0; i < N; ++i) { - output_flat.template chip<0>(segment_flat(i)) += - data_flat.template chip<0>(i); + Index j = internal::SubtleMustCopy(segment_flat(i)); + OP_REQUIRES(context, FastBoundsCheck(j, output_rows), + errors::InvalidArgument( + "segment_ids", SliceDebugString(segment_ids.shape(), i), + " = ", j, " is out of range [0, ", output_rows, ")")); + output_flat.template chip<0>(j) += data_flat.template chip<0>(i); } } } diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc index 2ba571bcdb..4bddcd7e98 100644 --- a/tensorflow/core/kernels/stack_ops.cc +++ b/tensorflow/core/kernels/stack_ops.cc @@ -181,47 +181,52 @@ class StackPushOp : public AsyncOpKernel { // Push the tensor onto the stack. Swap the tensor to CPU if instructed. 
const Tensor& tensor = ctx->input(1); AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1); - DeviceContext* device_ctxt = ctx->op_device_context(); - auto device = static_cast<tensorflow::Device*>(ctx->device()); - Allocator* allocator = device->GetAllocator(alloc_attrs); - AllocatorStats stats; - allocator->GetStats(&stats); + // For now, we use a simple heuristic for swapping: A GPU tensor is moved + // to CPU if the tensor has more than kCopyThreshold bytes and the GPU + // allocator says more than kOccupancy of the memory is in use. static constexpr int kCopyThreshold = 2048; static constexpr double kOccupancy = 0.7; if (swap_memory_ && !alloc_attrs.on_host() && std::is_same<Device, GPUDevice>::value && - stats.bytes_in_use > (stats.bytes_limit * kOccupancy) && tensor.TotalBytes() > kCopyThreshold) { - // Asynchronously copy the tensor from GPU to CPU memory. - // TODO(yuanbyu): Swap the oldest tensor first. - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs); - Tensor* cpu_tensor = - new Tensor(cpu_allocator, tensor.dtype(), tensor.shape()); - device_ctxt->CopyDeviceTensorToCPU( - &tensor, "StackPush", device, cpu_tensor, - [cpu_tensor, stack, ctx, done](const Status& s) { - ctx->SetStatus(s); - if (s.ok()) { - AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1); - ctx->SetStatus(stack->Push( - {PersistentTensor(*cpu_tensor), alloc_attrs, true})); - } - if (ctx->status().ok()) { - ctx->set_output(0, *cpu_tensor); - } - done(); - delete cpu_tensor; - }); - } else { - // Execute synchronously if not swapped. 
- OP_REQUIRES_OK( - ctx, stack->Push({PersistentTensor(tensor), alloc_attrs, false})); - ctx->set_output(0, tensor); - done(); + DeviceContext* device_ctxt = ctx->op_device_context(); + auto device = static_cast<tensorflow::Device*>(ctx->device()); + Allocator* allocator = device->GetAllocator(alloc_attrs); + AllocatorStats stats; + allocator->GetStats(&stats); + if (stats.bytes_in_use > (stats.bytes_limit * kOccupancy)) { + // Asynchronously copy the tensor from GPU to CPU memory. + // TODO(yuanbyu): Swap the oldest tensor first. + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_gpu_compatible(true); + host_alloc_attrs.set_on_host(true); + Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs); + Tensor* cpu_tensor = + new Tensor(cpu_allocator, tensor.dtype(), tensor.shape()); + device_ctxt->CopyDeviceTensorToCPU( + &tensor, "StackPush", device, cpu_tensor, + [cpu_tensor, stack, ctx, done](const Status& s) { + ctx->SetStatus(s); + if (s.ok()) { + AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1); + ctx->SetStatus(stack->Push( + {PersistentTensor(*cpu_tensor), alloc_attrs, true})); + } + if (ctx->status().ok()) { + ctx->set_output(0, *cpu_tensor); + } + done(); + delete cpu_tensor; + }); + return; + } } + + // Execute synchronously if not swapped. 
+ OP_REQUIRES_OK(ctx, + stack->Push({PersistentTensor(tensor), alloc_attrs, false})); + ctx->set_output(0, tensor); + done(); } bool IsExpensive() override { return false; } diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 88786ec774..5ecef9c6f9 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -169,7 +169,7 @@ Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, .TypeConstraint<T>("T") \ .HostMemory("perm"), \ TransposeGpuOp); -TF_CALL_NUMBER_TYPES(REGISTER); +TF_CALL_POD_TYPES(REGISTER); #undef REGISTER #endif diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h index 5fc2d5d20d..dc8de09d2c 100644 --- a/tensorflow/core/lib/random/philox_random.h +++ b/tensorflow/core/lib/random/philox_random.h @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" // Function qualifiers that need to work on both CPU and GPU. -#ifdef __CUDA_ARCH__ +#if defined(__CUDACC__) // For nvcc. #define PHILOX_DEVICE_FUNC __host__ __device__ #define PHILOX_INLINE __inline__ diff --git a/tensorflow/core/platform/macros.h b/tensorflow/core/platform/macros.h index 9cc08eca52..c7d5e63a1b 100644 --- a/tensorflow/core/platform/macros.h +++ b/tensorflow/core/platform/macros.h @@ -50,8 +50,8 @@ limitations under the License. 
#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0)) #define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) #else -#define TF_PREDICT_FALSE(x) x -#define TF_PREDICT_TRUE(x) x +#define TF_PREDICT_FALSE(x) (x) +#define TF_PREDICT_TRUE(x) (x) #endif // A macro to disallow the copy constructor and operator= functions diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto index 5b4ee3e85a..9b8ec1b5ed 100644 --- a/tensorflow/core/protobuf/tensorflow_server.proto +++ b/tensorflow/core/protobuf/tensorflow_server.proto @@ -105,4 +105,9 @@ message ServerDef { // The default configuration for sessions that run on this server. ConfigProto default_session_config = 4; + + // The protocol to be used by this server. + // + // Acceptable values include: "grpc". + string protocol = 5; } diff --git a/tensorflow/examples/udacity/README.md b/tensorflow/examples/udacity/README.md index 8308a56766..a6d6f8742a 100644 --- a/tensorflow/examples/udacity/README.md +++ b/tensorflow/examples/udacity/README.md @@ -6,7 +6,7 @@ Course information can be found at https://www.udacity.com/course/deep-learning- Running the Docker container from the Google Cloud repository ------------------------------------------------------------- - docker run -p 8888:8888 -it --rm b.gcr.io/tensorflow-udacity/assignments + docker run -p 8888:8888 -it --rm b.gcr.io/tensorflow-udacity/assignments:0.3.0 Accessing the Notebooks ----------------------- @@ -61,9 +61,9 @@ This will allow you to save work and have access to generated files on the host Pushing a Google Cloud release ------------------------------ - V=0.2.0 + V=0.3.0 docker tag $USER/assignments b.gcr.io/tensorflow-udacity/assignments:$V - docker tag $USER/assignments b.gcr.io/tensorflow-udacity/assignments:latest + docker tag -f $USER/assignments b.gcr.io/tensorflow-udacity/assignments:latest gcloud docker push b.gcr.io/tensorflow-udacity/assignments History @@ -71,3 +71,4 @@ History * 0.1.0: Initial 
release. * 0.2.0: Many fixes, including lower memory footprint and support for Python 3. +* 0.3.0: Use 0.7.1 release. diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py index 507165b0aa..aaba57b3b0 100644 --- a/tensorflow/models/image/mnist/convolutional.py +++ b/tensorflow/models/image/mnist/convolutional.py @@ -285,7 +285,7 @@ def main(argv=None): # pylint: disable=unused-argument batch_data = train_data[offset:(offset + BATCH_SIZE), ...] batch_labels = train_labels[offset:(offset + BATCH_SIZE)] # This dictionary maps the batch data (as a numpy array) to the - # node in the graph is should be fed to. + # node in the graph it should be fed to. feed_dict = {train_data_node: batch_data, train_labels_node: batch_labels} # Run the graph and fetch some of the nodes. diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7533317ac3..6fcb33d39b 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -821,6 +821,7 @@ py_library( deps = [ ":framework", ":ops", + ":server_lib", ":session", ":training_ops", ], @@ -897,6 +898,7 @@ tf_py_wrap_cc( srcs = ["tensorflow.i"], swig_includes = [ "client/events_writer.i", + "client/server_lib.i", "client/tf_session.i", "framework/python_op_gen.i", "lib/core/py_func.i", @@ -915,6 +917,8 @@ tf_py_wrap_cc( ":py_record_writer_lib", ":python_op_gen", ":tf_session_helper", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_session", "//util/python:python_headers", ], @@ -940,6 +944,28 @@ py_library( ], ) +py_library( + name = "server_lib", + srcs = ["client/server_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":pywrap_tensorflow", + ], +) + +py_test( + name = "server_lib_test", + srcs = ["client/server_lib_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":extra_py_tests_deps", + ":framework", + ":framework_test_lib", + ":server_lib", + 
":session", + ], +) + # Just used by tests. tf_cuda_library( name = "construction_fails_op", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index aab8ada371..b04912b107 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -111,6 +111,7 @@ __all__ = make_all(__name__, # documentation, or remove. __all__.extend([ 'AttrValue', + 'ClusterDef', 'ConfigProto', 'Event', 'GPUOptions', @@ -119,7 +120,9 @@ __all__.extend([ 'GRAPH_DEF_VERSION_MIN_PRODUCER', 'GraphDef', 'GraphOptions', + 'GrpcServer', 'HistogramProto', + 'JobDef', 'LogMessage', 'NameAttrList', 'NodeDef', @@ -127,6 +130,7 @@ __all__.extend([ 'PaddingFIFOQueue', 'RunOptions', 'RunOutputs', + 'ServerDef', 'SessionLog', 'Summary', 'arg_max', diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py index b06b37b7d0..0ab8f9dce0 100644 --- a/tensorflow/python/client/client_lib.py +++ b/tensorflow/python/client/client_lib.py @@ -51,6 +51,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# NOTE(mrry): Support for `tf.GrpcServer` is currently experimental. +from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef +from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef +from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef +from tensorflow.python.client.server_lib import GrpcServer + from tensorflow.python.client.session import InteractiveSession from tensorflow.python.client.session import Session diff --git a/tensorflow/python/client/server_lib.i b/tensorflow/python/client/server_lib.i new file mode 100644 index 0000000000..835f883ef4 --- /dev/null +++ b/tensorflow/python/client/server_lib.i @@ -0,0 +1,88 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%nothread tensorflow::ServerInterface::Join; + +%include "tensorflow/python/platform/base.i" + +//%newobject tensorflow::NewServer; + +%typemap(in) const ServerDef& (tensorflow::ServerDef temp) { + char* c_string; + Py_ssize_t py_size; + if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + SWIG_fail; + } + + if (!temp.ParseFromString(string(c_string, py_size))) { + PyErr_SetString( + PyExc_TypeError, + "The ServerDef could not be parsed as a valid protocol buffer"); + SWIG_fail; + } + $1 = &temp; +} + +%typemap(in, numinputs=0) + std::unique_ptr<tensorflow::ServerInterface>* out_server ( + std::unique_ptr<tensorflow::ServerInterface> temp) { + $1 = &temp; +} + +%typemap(out) tensorflow::Status tensorflow::NewServer { + if (!$1.ok()) { + RaiseStatusNotOK($1, $descriptor(tensorflow::Status*)); + SWIG_fail; + } +} + +%typemap(argout) std::unique_ptr<tensorflow::ServerInterface>* out_server { + // TODO(mrry): Convert this to SWIG_POINTER_OWN when the issues with freeing + // a server are fixed. + $result = SWIG_NewPointerObj($1->release(), + $descriptor(tensorflow::ServerInterface*), + 0); +} + +%feature("except") tensorflow::ServerInterface::Join { + // Let other threads run while we wait for the server to shut down. 
+ Py_BEGIN_ALLOW_THREADS + $action + Py_END_ALLOW_THREADS +} + +%{ +#include "tensorflow/core/distributed_runtime/server_lib.h" + +using tensorflow::ServerDef; +%} + +%ignoreall + +%unignore tensorflow; +%unignore tensorflow::ServerDef; +%unignore tensorflow::ServerInterface; +%unignore tensorflow::ServerInterface::~ServerInterface; +%unignore tensorflow::ServerInterface::Start; +%unignore tensorflow::ServerInterface::Stop; +%unignore tensorflow::ServerInterface::Join; +%unignore tensorflow::ServerInterface::target; + +%unignore tensorflow::NewServer; + +%include "tensorflow/core/distributed_runtime/server_lib.h" + +%unignoreall diff --git a/tensorflow/python/client/server_lib.py b/tensorflow/python/client/server_lib.py new file mode 100644 index 0000000000..38612edf15 --- /dev/null +++ b/tensorflow/python/client/server_lib.py @@ -0,0 +1,86 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A Python interface for creating TensorFlow servers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six # pylint: disable=unused-import + +from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python import pywrap_tensorflow + + +class GrpcServer(object): + """An in-process TensorFlow server. 
+ + NOTE(mrry): This class is experimental and not yet suitable for use. + """ + + def __init__(self, server_def, start=True): + """Creates a new server with the given definition. + + Args: + server_def: A `tf.ServerDef` protocol buffer, describing the server to + be created (and the cluster of which it is a member). + start: (Optional.) Boolean, indicating whether to start the server after + creating it. Defaults to `True`. + """ + if not isinstance(server_def, tensorflow_server_pb2.ServerDef): + raise TypeError("server_def must be a tf.ServerDef") + + self._server = pywrap_tensorflow.NewServer(server_def.SerializeToString()) + if start: + self.start() + + def start(self): + """Starts this server.""" + self._server.Start() + + def stop(self): + """Stops this server. + + NOTE(mrry): This method is currently not implemented. + """ + # TODO(mrry): Implement this. + raise NotImplementedError("GrpcServer.stop()") + + def join(self): + """Blocks until the server has shut down. + + NOTE(mrry): Since `GrpcServer.stop()` is not currently implemented, this + method blocks forever. + """ + self._server.Join() + + @property + def target(self): + """Returns the target for a `tf.Session` to connect to this server. + + To create a + [`tf.Session`](../../api_docs/python/client.md#Session) that + connects to this server, use the following snippet: + + ```python + server = tf.GrpcServer(...) + with tf.Session(server.target): + # ... + ``` + + Returns: + A string containing a session target for this server. + """ + return self._server.target() diff --git a/tensorflow/python/client/server_lib_test.py b/tensorflow/python/client/server_lib_test.py new file mode 100644 index 0000000000..5705f363d5 --- /dev/null +++ b/tensorflow/python/client/server_lib_test.py @@ -0,0 +1,65 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.GrpcServer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +class GrpcServerTest(tf.test.TestCase): + + def _localServer(self): + server_def = tf.ServerDef(protocol="grpc") + job_def = server_def.cluster.job.add() + job_def.name = "local" + job_def.tasks[0] = "localhost:0" + server_def.job_name = job_def.name + server_def.task_index = 0 + return server_def + + def testRunStep(self): + server = tf.GrpcServer(self._localServer()) + server.start() + + with tf.Session(server.target) as sess: + c = tf.constant([[2, 1]]) + d = tf.constant([[1], [2]]) + e = tf.matmul(c, d) + print(sess.run(e)) + # TODO(mrry): Add `server.stop()` and `server.join()` when these work. + + def testMultipleSessions(self): + server = tf.GrpcServer(self._localServer()) + server.start() + + c = tf.constant([[2, 1]]) + d = tf.constant([[1], [2]]) + e = tf.matmul(c, d) + + sess_1 = tf.Session(server.target) + sess_2 = tf.Session(server.target) + + sess_1.run(e) + sess_2.run(e) + + sess_1.close() + sess_2.close() + # TODO(mrry): Add `server.stop()` and `server.join()` when these work. 
+ + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index 7a7c58b19f..7180f7d77c 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -127,8 +127,8 @@ _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange", "SessionInterface", "BaseSession", "NameAttrList", "AttrValue", "TensorArray", "OptimizerOptions", "CollectionDef", "MetaGraphDef", "QueueRunnerDef", - "SaverDef", "VariableDef", "TestCase", - ] + "SaverDef", "VariableDef", "TestCase", "GrpcServer", + "ClusterDef", "JobDef", "ServerDef"] def main(unused_argv): if not FLAGS.out_dir: diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 95c7cfc2cf..dabc474f42 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2526,9 +2526,11 @@ class Graph(object): return name @contextlib.contextmanager - def colocate_with(self, op): + def colocate_with(self, op, ignore_existing=False): """Returns a context manager that specifies an op to colocate with. + Note: this function is not for public use, only for internal libraries. + For example: ```python @@ -2543,6 +2545,9 @@ class Graph(object): Args: op: The op to colocate all created ops with. + ignore_existing: If true, only applies colocation of this op within + the context, rather than applying all colocation properties + on the stack. Raises: ValueError: if op is None. @@ -2569,6 +2574,10 @@ class Graph(object): device_fn_tmp = self._device_function_stack self._device_function_stack = [] + if ignore_existing: + current_stack = self._colocation_stack + self._colocation_stack = [] + self._colocation_stack.append(op) try: @@ -2578,6 +2587,10 @@ class Graph(object): self._device_function_stack = device_fn_tmp self._colocation_stack.pop() + # Reset the colocation stack if requested. 
+ if ignore_existing: + self._colocation_stack = current_stack + @contextlib.contextmanager def device(self, device_name_or_function): """Returns a context manager that specifies the default device to use. @@ -3007,8 +3020,8 @@ def device(device_name_or_function): return get_default_graph().device(device_name_or_function) -def colocate_with(op): - return get_default_graph().colocate_with(op) +def colocate_with(op, ignore_existing=False): + return get_default_graph().colocate_with(op, ignore_existing) def name_scope(name): diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index b5dbd3c6f6..cfc96a0cc8 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -1283,6 +1283,14 @@ class ColocationGroupTest(test_util.TensorFlowTestCase): c = constant_op.constant(4.0) self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups())) + def testColocationIgnoreStack(self): + a = constant_op.constant([2.0], name="a") + b = constant_op.constant(3.0, name="b") + with ops.colocate_with(a.op): + with ops.colocate_with(b.op, ignore_existing=True): + c = constant_op.constant(4.0) + self.assertEqual(set(["loc:@b"]), set(c.op.colocation_groups())) + def testColocateVariables(self): a = variables.Variable([2.0], name="a") with ops.colocate_with(a.op): diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 766b416f75..d93020e825 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -328,6 +328,19 @@ class ZerosLikeTest(tf.test.TestCase): z = tf.zeros_like(d) self.assertEqual(d.get_shape().as_list(), z.get_shape().as_list()) + def testZerosLikeDtype(self): + # Make sure zeros_like works even for dtypes that cannot be cast between + with self.test_session(): + shape = (3, 5) + dtypes = np.float32, np.complex64 + for in_type in dtypes: + x = 
np.arange(15).astype(in_type).reshape(*shape) + for out_type in dtypes: + y = tf.zeros_like(x, dtype=out_type).eval() + self.assertEqual(y.dtype, out_type) + self.assertEqual(y.shape, shape) + self.assertAllEqual(y, np.zeros(shape, dtype=out_type)) + class OnesTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 07f11354b9..77da519fcc 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -898,7 +898,7 @@ class ControlFlowTest(tf.test.TestCase): r = control_flow_ops.While(c, b, [n, v], parallel_iterations=1) r = tf.gradients(r[1], x)[0] - self.assertEqual(r.get_shape().as_list(), [None]) + self.assertEqual(r.get_shape(), tensor_shape.unknown_shape()) self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]})) def testWhileGrad_MultipleUses(self): diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py index 37541284d1..959268a544 100644 --- a/tensorflow/python/kernel_tests/decode_csv_op_test.py +++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py @@ -160,6 +160,21 @@ class DecodeCSVOpTest(tf.test.TestCase): args, expected_err_re="Unquoted fields cannot have quotes/CRLFs inside") + def testWrongDefaults(self): + args = { + "records": [",1", "0.2,2", "3.0adf,3"], + "record_defaults": [[1.0]] + } + + self._test(args, + expected_err_re="Expect 1 fields but have 2 in record 0") + + def testShortQuotedString(self): + args = {"records": ["\""], "record_defaults": [["default"]],} + + self._test(args, + expected_err_re="Quoted field has to end with quote followed.*") + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py index dd8a8350c8..97a064df9d 100644 --- 
a/tensorflow/python/kernel_tests/in_topk_op_test.py +++ b/tensorflow/python/kernel_tests/in_topk_op_test.py @@ -58,6 +58,14 @@ class InTopKTest(tf.test.TestCase): target = [0, 2] self._validateInTopK(predictions, target, 2, [False, False]) + def testBadTarget(self): + predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + target = [0, 80000] + with self.test_session(): + with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, + "target.*out of range"): + tf.nn.in_top_k(predictions, target, 2).eval() + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index fe858b78b1..be59ac08c2 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -560,7 +560,7 @@ class LSTMTest(tf.test.TestCase): for out0, out1 in zip(outputs0_values, outputs1_values): self.assertAllEqual(out0, out1) - def _testDynamicEquivalentToStaticRNN(self, use_gpu): + def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length): time_steps = 8 num_units = 3 num_proj = 4 @@ -569,7 +569,10 @@ class LSTMTest(tf.test.TestCase): input_values = np.random.randn(time_steps, batch_size, input_size) - sequence_length = np.random.randint(0, time_steps, size=batch_size) + if use_sequence_length: + sequence_length = np.random.randint(0, time_steps, size=batch_size) + else: + sequence_length = None ########### Step 1: Run static graph and generate readouts with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: @@ -744,8 +747,14 @@ class LSTMTest(tf.test.TestCase): self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=True) def testDynamicEquivalentToStaticRNN(self): - self._testDynamicEquivalentToStaticRNN(use_gpu=False) - self._testDynamicEquivalentToStaticRNN(use_gpu=True) + self._testDynamicEquivalentToStaticRNN( + use_gpu=False, use_sequence_length=False) + self._testDynamicEquivalentToStaticRNN( + use_gpu=True, 
use_sequence_length=False) + self._testDynamicEquivalentToStaticRNN( + use_gpu=False, use_sequence_length=True) + self._testDynamicEquivalentToStaticRNN( + use_gpu=True, use_sequence_length=True) class BidirectionalRNNTest(tf.test.TestCase): @@ -1091,7 +1100,7 @@ def rnn_long_sequence_benchmark(batch_size, seqlen, num_units, def main(_): print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM") print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)") - for max_time in (1, 25, 50): + for max_time in (1, 25, 50, 100, 200): graph_creation_static_vs_dynamic_rnn_benchmark(max_time) print("Calculation: Static Unroll with Dynamic Flow LSTM " diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 7257be5d59..44efdb538b 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -636,10 +636,12 @@ def zeros_like(tensor, dtype=None, name=None): """ with ops.op_scope([tensor], name, "zeros_like") as name: tensor = ops.convert_to_tensor(tensor, name="tensor") - ret = gen_array_ops._zeros_like(tensor) - if (dtype is not None) and (tensor.dtype != dtype): - ret = gen_math_ops.cast(ret, dtype) - return ret + if dtype is not None and tensor.dtype != dtype: + ret = zeros(shape(tensor), dtype, name=name) + ret.set_shape(tensor.get_shape()) + return ret + else: + return gen_array_ops._zeros_like(tensor, name=name) def ones_like(tensor, dtype=None, name=None): diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py index cc911ca24f..aa85c12931 100644 --- a/tensorflow/python/ops/control_flow_grad.py +++ b/tensorflow/python/ops/control_flow_grad.py @@ -36,17 +36,18 @@ def _SwitchGrad(op, *grad): the merge on the first visit, and update the other input of the merge on the second visit. A next_iteration is also added on second visit. 
""" - real_op = GetRealOp(op) + graph = ops.get_default_graph() # pylint: disable=protected-access - ctxt = real_op._get_control_flow_context() + op_ctxt = op._get_control_flow_context() + grad_ctxt = graph._get_control_flow_context() # pylint: enable=protected-access - if isinstance(ctxt, WhileContext): - merge_op = op.grad_state.switch_map.get(real_op) + if isinstance(op_ctxt, WhileContext): + merge_op = grad_ctxt.grad_state.switch_map.get(op) if merge_op: # This is the second time this Switch is visited. It comes from # the non-exit branch of the Switch, so update the second input # to the Merge. - # TODO: Need to perform shape inference with this new input. + # TODO: Perform shape inference with this new input. # pylint: disable=protected-access merge_op._update_input(1, control_flow_ops._NextIteration(grad[1])) # pylint: enable=protected-access @@ -58,21 +59,22 @@ def _SwitchGrad(op, *grad): # input of merge when we see this Switch the second time. merge_fn = control_flow_ops._Merge # pylint: disable=protected-access merge_op = merge_fn([grad[0], grad[0]], name="b_switch")[0] - op.grad_state.switch_map[real_op] = merge_op.op + grad_ctxt.grad_state.switch_map[op] = merge_op.op return merge_op, None - elif isinstance(ctxt, CondContext): - good_grad = grad[ctxt.branch] - zero_grad = grad[1 - ctxt.branch] - # If this Switch is wrapped, it is part of a cond within a loop. In - # this case, we have called ControlFlowState.ZeroLike() so grad is - # ready for merge. Otherwise, we need a switch to control zero_grad. - if not isinstance(op, ControlFlowOpWrapper): + elif isinstance(op_ctxt, CondContext): + good_grad = grad[op_ctxt.branch] + zero_grad = grad[1 - op_ctxt.branch] + # If we are in a grad context, this switch is part of a cond within a + # loop. In this case, we have called ControlFlowState.ZeroLike() so grad + # is ready for merge. Otherwise, we need a switch to control zero_grad. 
+ if not (grad_ctxt and grad_ctxt.grad_state): dtype = good_grad.dtype - zero_grad = switch(zero_grad, ctxt.pred, dtype=dtype)[1 - ctxt.branch] + branch = op_ctxt.branch + zero_grad = switch(zero_grad, op_ctxt.pred, dtype=dtype)[1 - branch] return merge([good_grad, zero_grad], name="cond_grad")[0], None else: - false_grad = switch(grad[0], real_op.inputs[1])[0] - true_grad = switch(grad[1], real_op.inputs[1])[1] + false_grad = switch(grad[0], op.inputs[1])[0] + true_grad = switch(grad[1], op.inputs[1])[1] return merge([false_grad, true_grad])[0], None @@ -83,24 +85,24 @@ ops.RegisterGradient("RefSwitch")(_SwitchGrad) @ops.RegisterGradient("Merge") def _MergeGrad(op, grad, _): """Gradients for a Merge op are calculated using a Switch op.""" - real_op = GetRealOp(op) - input_op = real_op.inputs[0].op + input_op = op.inputs[0].op + graph = ops.get_default_graph() # pylint: disable=protected-access - ctxt = input_op._get_control_flow_context() + op_ctxt = input_op._get_control_flow_context() + grad_ctxt = graph._get_control_flow_context() # pylint: enable=protected-access - if isinstance(ctxt, WhileContext): - grad_ctxt = op.grad_state.grad_context + if isinstance(op_ctxt, WhileContext): # pylint: disable=protected-access return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot) # pylint: enable=protected-access - elif isinstance(ctxt, CondContext): - pred = ctxt.pred - if isinstance(op, ControlFlowOpWrapper): + elif isinstance(op_ctxt, CondContext): + pred = op_ctxt.pred + if grad_ctxt and grad_ctxt.grad_state: # This Merge node is part of a cond within a loop. # The backprop needs to have the value of this predicate for every # iteration. So we must have its values accumulated in the forward, and # use the accumulated values as the predicate for this backprop switch. - grad_state = op.grad_state + grad_state = grad_ctxt.grad_state real_pred = grad_state.history_map.get(pred.name) if not real_pred: # Remember the value of pred for every iteration. 
@@ -118,8 +120,8 @@ def _MergeGrad(op, grad, _): return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad") # pylint: enable=protected-access else: - num_inputs = len(real_op.inputs) - cond = [math_ops.equal(real_op.outputs[1], i) for i in xrange(num_inputs)] + num_inputs = len(op.inputs) + cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)] # pylint: disable=protected-access return [control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1] for i in xrange(num_inputs)] @@ -132,16 +134,17 @@ def _RefMergeGrad(op, grad, _): @ops.RegisterGradient("Exit") -def _ExitGrad(op, grad): +def _ExitGrad(_, grad): """Gradients for an exit op are calculated using an Enter op.""" - real_op = GetRealOp(op) + graph = ops.get_default_graph() # pylint: disable=protected-access - forward_ctxt = real_op._get_control_flow_context() + grad_ctxt = graph._get_control_flow_context() # pylint: enable=protected-access - if not forward_ctxt.back_prop: - # No gradient computation for this loop. + if not grad_ctxt.back_prop: + # The flag `back_prop` is set by users to suppress gradient + # computation for this loop. If the flag `back_prop` is false, + # no gradient computation. return None - grad_ctxt = op.grad_state.grad_context grad_ctxt.AddName(grad.name) enter_fn = control_flow_ops._Enter # pylint: disable=protected-access grad_ctxt.Enter() @@ -176,17 +179,14 @@ def _EnterGrad(op, grad): For loop variables, grad is the gradient so just add an exit. For loop invariants, we need to add an accumulator loop. """ - real_op = GetRealOp(op) + graph = ops.get_default_graph() # pylint: disable=protected-access - forward_ctxt = real_op._get_control_flow_context() + grad_ctxt = graph._get_control_flow_context() # pylint: enable=protected-access - if not forward_ctxt.back_prop: - # The flag `back_prop` is set by users to suppress gradient - # computation for this loop. If the flag `back_prop` is true, - # no gradient computation.
+ if not grad_ctxt.back_prop: + # If the flag `back_prop` is false, no gradient computation. + return grad - grad_ctxt = op.grad_state.grad_context - if real_op.get_attr("is_constant"): + if op.get_attr("is_constant"): # Add a gradient accumulator for each loop invariant. result = grad_ctxt.AddBackPropAccumulator(grad) else: diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index cfd2bed9c5..e7f8a6d76c 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -279,23 +279,23 @@ def _SwitchRefOrTensor(data, pred, name="Switch"): TypeError: if data is not a Tensor or IndexedSlices """ data = ops.convert_to_tensor_or_indexed_slices(data, name="data") - # NOTE(mrry): ops.device(None) below addresses the following scenario. + # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below + # addresses the following scenario. # # Assume you execute Optimizer.apply_gradients() in a branch of a cond(). # - # 1. The update op is created inside a `with tf.device(var.device):` block - # say var.device = "/job:ps/task:1". + # 1. The update op is created inside a `with ops.colocate_with(var):` block # # 2. Some tensor `data` is captured and a switch is created in a - # `with tf.device(data.device):` block (data.device = "/job:worker_train"). + # `with ops.colocate_with(data):` block. # - # with tf.device("/job:ps/task:1"): - # with tf.device("/job:worker_train"): + # with ops.colocate_with(var): + # with ops.colocate_with(data): # op = ... # - # But then calling `print op.device` returns: - # ==> "/job:worker_train/task:1" -- a device that doesn't exist in this case! - with ops.colocate_with(data): + # var and data may be pinned to different devices, so we want ops + # created within ops.colocate_with(data) to ignore the existing stack.
+ with ops.colocate_with(data, ignore_existing=True): if isinstance(data, ops.Tensor): if not data.dtype.is_ref_dtype: return switch(data, pred, name=name) @@ -324,142 +324,21 @@ def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows): for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)] -class ControlFlowOpWrapper(object): - """A wrapper class for Operation. - - A wrapped op allows us to capture the uses of its inputs and outputs. In - gradients(), right before calling the gradient function of an op, we wrap - the op by calling MakeWrapper. So during the exection of the gradient - function of an op , any time when one of its inputs/outputs is used, we - generate code to remember its values for all iterations. - """ - - class _ControlFlowOpInputs(object): - """An indirection to capture the input tensors needed in backprop.""" - - def __init__(self, op, grad_state): - self._op = op - self._grad_state = grad_state - self._inputs = None - - def __len__(self): - return len(self._op._inputs) - - def __getitem__(self, index): - if self._inputs is None: - self._inputs = [None for _ in self._op.inputs] - if isinstance(index, int): - val = self._inputs[index] - if val is None: - f_val = self._op.inputs[index] - val = self._grad_state.GetRealValue(f_val) - self._inputs[index] = val - return val - elif isinstance(index, slice): - start, stop, step = index.indices(len(self)) - vals = [self[i] for i in xrange(start, stop, step)] - return vals - else: - raise TypeError("index must be an integer or slice") - - class _ControlFlowOpOutputs(object): - """An indirection to capture the output tensors needed in backprop.""" - - def __init__(self, op, grad_state): - self._op = op - self._grad_state = grad_state - self._outputs = None - - def __len__(self): - return len(self._op._outputs) - - def __getitem__(self, index): - if self._outputs is None: - self._outputs = [None for _ in self._op.outputs] - if isinstance(index, int): - val = 
self._outputs[index] - if val is None: - f_val = self._op.outputs[index] - val = self._grad_state.GetRealValue(f_val) - self._outputs[index] = val - return val - elif isinstance(index, slice): - start, stop, step = index.indices(len(self)) - vals = [self[i] for i in xrange(start, stop, step)] - return vals - else: - raise TypeError("index must be an integer or slice") - - def __init__(self, op, grad_state): - self._grad_state = grad_state # The GradLoopState this op belongs to. - self._op = op - self._inputs = None - self._outputs = None - - @property - def grad_state(self): - return self._grad_state - - @property - def inputs(self): - if self._inputs is None: - self._inputs = self._ControlFlowOpInputs(self._op, self._grad_state) - return self._inputs - - @property - def outputs(self): - if self._outputs is None: - self._outputs = self._ControlFlowOpOutputs(self._op, self._grad_state) - return self._outputs - - @property - def op(self): - return self._op - - @property - def name(self): - """Returns the name of this instance of op.""" - return self._op.name - - @property - def _id(self): - """Returns the unique id of this operation.""" - return self._op._id - - @property - def device(self): - """Returns the device of this operation. - - Returns: - a string or None if the device was not set. 
- """ - return self._op.device - - @property - def type(self): - """Returns the type of the op.""" - return self._op.type - - @property - def graph(self): - """The `Graph` that contains this operation.""" - return self._op.graph - - def get_attr(self, name): - """Returns the value of the attr of this op with the given `name`.""" - return self._op.get_attr(name) - - def _get_control_flow_context(self): - """Returns the control flow context of this op.""" - return self._op._get_control_flow_context() - - def _IsLoopConstantEnter(op): - """Returns true iff op is a loop invariant.""" + """Return true iff op is a loop invariant.""" is_enter = (op.type == "Enter" or op.type == "RefEnter") return is_enter and op.get_attr("is_constant") +def _GetLoopConstantEnter(value): + """Return the enter op if we can infer `value` to be a loop invariant.""" + id_ops = {"Switch", "RefSwitch", "Identity", "RefIdentity"} + op = value.op + while op.type in id_ops: + op = op.inputs[0].op + return op if _IsLoopConstantEnter(op) else None + + def _IsLoopExit(op): return op.type == "Exit" or op.type == "RefExit" @@ -531,7 +410,8 @@ class GradLoopState(object): self._grad_context = WhileContext(forward_ctxt.parallel_iterations, forward_ctxt.back_prop, forward_ctxt.swap_memory, - forward_ctxt.name) + forward_ctxt.name, + self) real_cnt = outer_grad_state.AddBackPropAccumulatedValue(history_cnt, cnt) self._grad_index = self._grad_context.AddBackPropCounter(real_cnt) outer_grad_ctxt.Exit() @@ -540,7 +420,8 @@ class GradLoopState(object): self._grad_context = WhileContext(forward_ctxt.parallel_iterations, forward_ctxt.back_prop, forward_ctxt.swap_memory, - forward_ctxt.name) + forward_ctxt.name, + self) self._grad_index = self._grad_context.AddBackPropCounter(cnt) if outer_forward_ctxt: outer_forward_ctxt.Exit() @@ -629,55 +510,59 @@ class GradLoopState(object): edge from the push op to either `forward_index.op` or `forward_sync`. Args: - value: The tensor that is to be accumulated. 
+ value: The source tensor in forward that is to be accumulated. dead_branch: True iff the tensor is on a dead branch of a cond. Returns: The stack that contains the accumulated history of the tensor. """ - # TODO(yuanbyu): Make sure the colocation of stack ops and value. - # pylint: disable=protected-access - acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc") - # pylint: enable=protected-access - - # Make acc available in the forward context. - enter_acc = self.forward_context.AddValue(acc) - - # Add the stack_push op in the context of value.op. - swap_enabled = self.forward_context.swap_memory - value_ctxt = value.op._get_control_flow_context() - if _IsLoopExit(value.op): - value_ctxt = value_ctxt.outer_context - if value_ctxt == self.forward_context: - # value is not nested in the forward context. - self.forward_context.Enter() - push = gen_data_flow_ops._stack_push(enter_acc, value, - swap_memory=swap_enabled) - self.forward_context.Exit() - # Protect stack push and order it before forward_index. - self.forward_index.op._add_control_input(push.op) - else: - # value is in a cond context within the forward context. - assert isinstance(value_ctxt, CondContext) - if dead_branch: - # The special case for creating a zero tensor for a dead - # branch of a switch. See ControlFlowState.ZerosLike(). - value_ctxt.outer_context.Enter() - push = gen_data_flow_ops._stack_push(enter_acc, value, - swap_memory=swap_enabled) - value_ctxt.outer_context.Exit() - push.op._set_control_flow_context(value_ctxt) + curr_ctxt = ops.get_default_graph()._get_control_flow_context() + with ops.control_dependencies(None): + if curr_ctxt: curr_ctxt.Enter() + with ops.colocate_with(value): + # pylint: disable=protected-access + acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc") + # pylint: enable=protected-access + if curr_ctxt: curr_ctxt.Exit() + + # Make acc available in the forward context. 
+ enter_acc = self.forward_context.AddValue(acc) + + # Add the stack_push op in the context of value.op. + swap_enabled = self.forward_context.swap_memory + value_ctxt = value.op._get_control_flow_context() + if _IsLoopExit(value.op): + value_ctxt = value_ctxt.outer_context + if value_ctxt == self.forward_context: + # value is not nested in the forward context. + self.forward_context.Enter() + push = gen_data_flow_ops._stack_push( + enter_acc, value, swap_memory=swap_enabled) + self.forward_context.Exit() + # Protect stack push and order it before forward_index. + self.forward_index.op._add_control_input(push.op) else: - value_ctxt.Enter() - push = gen_data_flow_ops._stack_push(enter_acc, value, - swap_memory=swap_enabled) - value_ctxt.Exit() - # Protect stack push and order it before forward_sync. - self.forward_sync._add_control_input(push.op) - # Order stack push after the successor of forward_index - add_op = self.forward_index.op.inputs[0].op - push.op._add_control_input(add_op) - return acc + # value is in a cond context within the forward context. + assert isinstance(value_ctxt, CondContext) + if dead_branch: + # The special case for creating a zero tensor for a dead + # branch of a switch. See ControlFlowState.ZerosLike(). + value_ctxt.outer_context.Enter() + push = gen_data_flow_ops._stack_push( + enter_acc, value, swap_memory=swap_enabled) + value_ctxt.outer_context.Exit() + push.op._set_control_flow_context(value_ctxt) + else: + value_ctxt.Enter() + push = gen_data_flow_ops._stack_push( + enter_acc, value, swap_memory=swap_enabled) + value_ctxt.Exit() + # Protect stack push and order it before forward_sync. 
+ self.forward_sync._add_control_input(push.op) + # Order stack push after the successor of forward_index + add_op = self.forward_index.op.inputs[0].op + push.op._add_control_input(add_op) + return acc def AddBackPropAccumulatedValue(self, history_value, value, dead_branch=False): @@ -704,60 +589,67 @@ class GradLoopState(object): cond_ctxt = value_ctxt break value_ctxt = value_ctxt.outer_context - if cond_ctxt: - # Guard stack pop with a switch if it is controlled by a cond - grad_state = self - pred = None - while not pred and grad_state: - pred = grad_state.history_map.get(cond_ctxt.pred.name) - grad_state = grad_state.outer_grad_state - branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch - history_value = _SwitchRefOrTensor(history_value, pred)[branch] - pop = gen_data_flow_ops._stack_pop(history_value, value.dtype.base_dtype) + with ops.control_dependencies(None): + self.grad_context.Enter() + if cond_ctxt: + # Guard stack pop with a switch if it is controlled by a cond + grad_state = self + pred = None + while not pred and grad_state: + pred = grad_state.history_map.get(cond_ctxt.pred.name) + grad_state = grad_state.outer_grad_state + branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch + history_value = _SwitchRefOrTensor(history_value, pred)[branch] + pop = gen_data_flow_ops._stack_pop(history_value, value.dtype.base_dtype) + self.grad_context.Exit() if self.grad_context.parallel_iterations > 1: # All pops are ordered after pivot_for_body and before grad_sync. self.grad_sync._add_control_input(pop.op) return pop def GetRealValue(self, value): - """Get the real value. + """Get the real value of `value`. - If backprop "uses" a value produced by forward inference, an - accumulator is added in the forward loop to accumulate its values. - We use the accumulated value. + If backprop "uses" a value produced by forward inference, an accumulator + is added in the forward loop to accumulate its values. We use the + accumulated value. 
This method must be called in the grad loop context. + `value` must be in forward and needed for backprop. Args: value: A tensor to be captured. Returns: - The same tensor value from the saved history. + The same tensor obtained from the saved history. """ assert value.op.type != "Variable" real_value = self._history_map.get(value.name) if real_value is None: - if _IsLoopConstantEnter(value.op): - # Special case for loop invariant. - if self._outer_grad_state: - # This is a nested loop so we record the history of this - # value in outer_forward_ctxt. + cur_value = value + cur_grad_state = self + while True: + enter_op = _GetLoopConstantEnter(cur_value) + if enter_op: + # Special case: cur_value comes from a constant Enter node. + cur_value = enter_op.inputs[0] + if self._outer_grad_state: + cur_grad_state = cur_grad_state.outer_grad_state + else: + # We are now outside all nested loops for this gradient(), + # so `value` is a loop invariant and there is no need to + # save the history of value. + real_value = self._grad_context.AddValue(cur_value) + break + else: + # Record the history of this value in forward_ctxt. + # TODO(yuanbyu): Avoid recording constants. self._grad_context.Exit() - outer_value = value.op.inputs[0] - history_value = self._outer_grad_state.AddForwardAccumulator( - outer_value) + h_value = cur_grad_state.AddForwardAccumulator(cur_value) self._grad_context.Enter() - else: - # Just use the input value of this Enter node. - real_value = GetRealOp(value.op).inputs[0] - else: - # Record the history of this value in forward_ctxt. - # NOTE(yuanbyu): Don't record for constants. - self._grad_context.Exit() - history_value = self.AddForwardAccumulator(value) - self._grad_context.Enter() + break if real_value is None: # Add the stack pop op in the grad context. 
- real_value = self.AddBackPropAccumulatedValue(history_value, value) + real_value = self.AddBackPropAccumulatedValue(h_value, value) self._history_map[value.name] = real_value return real_value @@ -776,9 +668,9 @@ class ControlFlowState(object): def __init__(self): self._map = {} # maps forward loop context to GradLoopState - def _GetGradState(self, op): - """Get the gradient loop state for this op if any.""" - if _IsLoopExit(op): + def _GetGradState(self, op, before): + """Return the grad state for this op if it's in a forward loop context.""" + if before and _IsLoopExit(op): forward_ctxt = op._get_control_flow_context() forward_ctxt = forward_ctxt.outer_context if forward_ctxt: @@ -789,15 +681,6 @@ class ControlFlowState(object): return self._map.get(forward_ctxt) return None - def MakeWrapper(self, op): - """Make a wrapper for op if it is in a WhileContext.""" - forward_ctxt = _GetWhileContext(op) - if forward_ctxt: - grad_state = self._map.get(forward_ctxt) - if grad_state: - return ControlFlowOpWrapper(op, grad_state) - return op - def GetAllLoopExits(self): """Return a list containing the exits of all the loops.""" loop_exits = [] @@ -806,15 +689,15 @@ class ControlFlowState(object): loop_exits.append(loop_exit) return loop_exits - def EnterGradWhileContext(self, op): + def EnterGradWhileContext(self, op, before): """Enter the WhileContext for gradient computation.""" - grad_state = self._GetGradState(op) + grad_state = self._GetGradState(op, before) if grad_state: grad_state.grad_context.Enter() - def ExitGradWhileContext(self, op): + def ExitGradWhileContext(self, op, before): """Exit the WhileContext for gradient computation.""" - grad_state = self._GetGradState(op) + grad_state = self._GetGradState(op, before) if grad_state: grad_state.grad_context.Exit() @@ -877,12 +760,18 @@ class ControlFlowState(object): result = array_ops.zeros(val_shape.dims, val.dtype) outer_grad_state.grad_context.Exit() else: - history_val = 
outer_grad_state.AddForwardAccumulator(val) + # Only the shape of value is needed for backprop. + forward_ctxt.outer_context.Enter() + shape = array_ops.shape(value) + forward_ctxt.outer_context.Exit() + # Save the shape to a stack. + history_shape = outer_grad_state.AddForwardAccumulator(shape) + # Get the shape back from the stack. outer_grad_ctxt = outer_grad_state.grad_context outer_grad_ctxt.Enter() - real_val = outer_grad_state.AddBackPropAccumulatedValue( - history_val, val) - result = array_ops.zeros_like(real_val) + real_shape = outer_grad_state.AddBackPropAccumulatedValue( + history_shape, shape) + result = array_ops.zeros(real_shape, value.dtype) outer_grad_ctxt.Exit() else: # This is not a nested loop. @@ -943,23 +832,17 @@ class ControlFlowState(object): # Add forward accumulator for shape. grad_state.grad_context.Exit() - history_shape = grad_state.AddForwardAccumulator(zeros_shape, dead_branch) + h_shape = grad_state.AddForwardAccumulator( + zeros_shape, dead_branch=dead_branch) grad_state.grad_context.Enter() # Create a zero tensor with the right shape. shape = grad_state.AddBackPropAccumulatedValue( - history_shape, zeros_shape, dead_branch) + h_shape, zeros_shape, dead_branch) result = array_ops.zeros(shape, val.dtype) return result -def GetRealOp(op): - """Get the real op by removing the wrapper.""" - while isinstance(op, ControlFlowOpWrapper): - op = op.op - return op - - def MaybeCreateControlFlowState(between_op_list, between_ops): """Create the state for all the while loops involved in one gradients(). 
@@ -1106,6 +989,9 @@ class CondContext(ControlFlowContext): return result def AddOp(self, op): + self._AddOpInternal(op) + + def _AddOpInternal(self, op): """Add `op` to the current context.""" if not op.inputs: # Add this op to the enclosing while context @@ -1248,7 +1134,8 @@ def cond(pred, fn1, fn2, name=None): class WhileContext(ControlFlowContext): """The context for the loop construct.""" - def __init__(self, parallel_iterations, back_prop, swap_memory, name): + def __init__(self, parallel_iterations, back_prop, swap_memory, name, + grad_state=None): ControlFlowContext.__init__(self) self._name = ops.get_default_graph().unique_name(name) self._parallel_iterations = parallel_iterations @@ -1263,6 +1150,8 @@ class WhileContext(ControlFlowContext): self._pivot = None # The list of exit tensors for loop variables. self._loop_exits = None + # The gradient loop state + self._grad_state = grad_state @property def name(self): @@ -1293,6 +1182,11 @@ class WhileContext(ControlFlowContext): """The list of exit tensors for loop variables.""" return self._loop_exits + @property + def grad_state(self): + """The gradient loop state.""" + return self._grad_state + def GetWhileContext(self): return self @@ -1306,6 +1200,22 @@ class WhileContext(ControlFlowContext): result = val if val.name not in self._values: self._values.add(val.name) + + # If we are in a grad context and val is from its forward context, + # use GetRealValue(), which adds the logic to save the history of + # val in forward. 
+ grad_ctxt = ops.get_default_graph()._get_control_flow_context() + if grad_ctxt: + grad_ctxt = grad_ctxt.GetWhileContext() + if grad_ctxt.grad_state: + forward_ctxt = _GetWhileContext(val.op) + if _IsLoopExit(val.op): + forward_ctxt = forward_ctxt.outer_context + if forward_ctxt == grad_ctxt.grad_state.forward_context: + real_val = grad_ctxt.grad_state.GetRealValue(val) + self._external_values[val.name] = real_val + return real_val + if self._outer_context is not None: result = self._outer_context.AddValue(val) # Create an Enter to make `result` known to this loop context. @@ -1327,7 +1237,27 @@ class WhileContext(ControlFlowContext): return result def AddOp(self, op): - """Adds `op` to the current context.""" + """Add `op` to the current context.""" + # For a reduction op, if op is in a grad context and its input is from + # its forward context, moving op to the forward context means we would + # store the tensor after the reduction as opposed to the tensor before + # reduction, and therefore could significantly reduce memory consumption. + # For now, we do this only for a few ops. + if op.type in {"Shape", "Size", "Rank"}: + grad_ctxt = ops.get_default_graph()._get_control_flow_context() + if grad_ctxt: + grad_ctxt = grad_ctxt.GetWhileContext() + if grad_ctxt.grad_state: + op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op) + if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context: + op_input_ctxt = op.inputs[0].op._get_control_flow_context() + op._set_control_flow_context(op_input_ctxt) + op_input_ctxt._AddOpInternal(op) + return + self._AddOpInternal(op) + + def _AddOpInternal(self, op): + """Add `op` to the current context.""" if not op.inputs: if not op.control_inputs: # Add a control edge from the control pivot to this op. @@ -1863,7 +1793,6 @@ def foldr(fn, elems, initializer=None, name=None): fn: The function to be performed. elems: A tensor that is unpacked into a sequence of tensors to apply `fn`. 
initializer: (optional) The initial value for the accumulator. - use_tensor_array: (optional) use tensor_array if true. name: (optional) Name prefix for the returned tensors. Returns: diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index ced841f269..9fc1aa80d1 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -427,14 +427,14 @@ def gradients(ys, op = queue.popleft() with _maybe_colocate_with(op, colocate_gradients_with_ops): if loop_state: - loop_state.EnterGradWhileContext(op) + loop_state.EnterGradWhileContext(op, before=True) out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method) - grad_fn = None + if loop_state: + loop_state.ExitGradWhileContext(op, before=True) + grad_fn = None # pylint: disable=protected-access is_func_call = ops.get_default_graph()._is_function(op.type) - # pylint: enable=protected-access - if not is_func_call and any(out_grads) and op._id not in stop_ops: # pylint: enable=protected-access # A grad_fn must be defined, either as a function or as None @@ -445,6 +445,9 @@ def gradients(ys, raise LookupError( "No gradient defined for operation '%s' (op type: %s)" % (op.name, op.type)) + + if loop_state: + loop_state.EnterGradWhileContext(op, before=False) if (grad_fn or is_func_call) and any(out_grads): # NOTE: If _AggregatedGrads didn't compute a value for the i'th # output, it means that the cost does not depend on output[i], @@ -461,9 +464,6 @@ def gradients(ys, # pylint: disable=protected-access with ops.get_default_graph()._original_op(op): # pylint: enable=protected-access - wrapped_op = op - if loop_state: - wrapped_op = loop_state.MakeWrapper(op) if is_func_call: # For function call ops, we add a 'SymbolicGradient' # node to the graph to compute gradients. 
@@ -474,7 +474,7 @@ def gradients(ys, f_in, f_types, op.type)) # pylint: enable=protected-access else: - in_grads = _AsList(grad_fn(wrapped_op, *out_grads)) + in_grads = _AsList(grad_fn(op, *out_grads)) _VerifyGeneratedGradients(in_grads, op) if gate_gradients and len(tuple(filter(None, in_grads))) > 1: in_grads = control_flow_ops.tuple(in_grads) @@ -491,7 +491,7 @@ def gradients(ys, if in_grad: _SetGrad(grads, t_in, in_grad) if loop_state: - loop_state.ExitGradWhileContext(op) + loop_state.ExitGradWhileContext(op, before=False) # update pending count for the inputs of op. # pylint: disable=protected-access diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 619922983a..d65910abb2 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -644,7 +644,8 @@ def resize_images(images, new_width_const = tensor_util.constant_value(new_width) new_height_const = tensor_util.constant_value(new_height) - if width == new_width_const and height == new_height_const: + if new_width_const is not None and new_height_const is not None and ( + width == new_width_const and height == new_height_const): if not is_batch: images = array_ops.squeeze(images, squeeze_dims=[0]) return images diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index e9a029259c..611f5fa314 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -269,6 +269,10 @@ def _reverse_seq(input_seq, lengths): # Join into (time, batch_size, depth) s_joined = array_ops.pack(input_seq) + # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32 + if lengths is not None: + lengths = math_ops.to_int64(lengths) + # Reverse along dimension 0 s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) # Split again into list @@ -346,9 +350,9 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs, return (outputs, output_state_fw, output_state_bw) -def dynamic_rnn(cell, inputs, sequence_length, 
initial_state=None, dtype=None, - parallel_iterations=None, swap_memory=False, time_major=False, - scope=None): +def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, + dtype=None, parallel_iterations=None, swap_memory=False, + time_major=False, scope=None): """Creates a recurrent neural network specified by RNNCell "cell". This function is functionally identical to the function `rnn` above, but @@ -369,9 +373,9 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None, `[batch_size, max_time, cell.input_size]`. If time_major == True, this must be a tensor of shape: `[max_time, batch_size, cell.input_size]`. - sequence_length: An int32/int64 vector (tensor) size [batch_size]. + sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. initial_state: (optional) An initial state for the RNN. This must be - a tensor of appropriate type and shape [batch_size x cell.state_size]. + a tensor of appropriate type and shape `[batch_size x cell.state_size]`. dtype: (optional) The data type for the initial state. Required if initial_state is not provided. parallel_iterations: (Default: 32). The number of iterations to run in @@ -415,8 +419,10 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None, inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,D) => (T,B,D) parallel_iterations = parallel_iterations or 32 - sequence_length = math_ops.to_int32(sequence_length) - sequence_length = array_ops.identity(sequence_length, name="sequence_length") + if sequence_length is not None: + sequence_length = math_ops.to_int32(sequence_length) + sequence_length = array_ops.identity( # Just to find it in the graph. 
+ sequence_length, name="sequence_length") # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached @@ -442,15 +448,16 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None, ["Expected shape for Tensor %s is " % x.name, packed_shape, " but saw shape: ", x_shape]) - # Perform some shape validation - with ops.control_dependencies( - [_assert_has_shape(sequence_length, [batch_size])]): - sequence_length = array_ops.identity(sequence_length, name="CheckSeqLen") + if sequence_length is not None: + # Perform some shape validation + with ops.control_dependencies( + [_assert_has_shape(sequence_length, [batch_size])]): + sequence_length = array_ops.identity( + sequence_length, name="CheckSeqLen") (outputs, final_state) = _dynamic_rnn_loop( - cell, inputs, state, sequence_length, - parallel_iterations=parallel_iterations, - swap_memory=swap_memory) + cell, inputs, state, parallel_iterations=parallel_iterations, + swap_memory=swap_memory, sequence_length=sequence_length) # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. # If we are performing batch-major calculations, transpose output back @@ -461,17 +468,18 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None, return (outputs, final_state) -def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length, - parallel_iterations, swap_memory): +def _dynamic_rnn_loop( + cell, inputs, initial_state, parallel_iterations, swap_memory, + sequence_length=None): """Internal implementation of Dynamic RNN. Args: cell: An instance of RNNCell. inputs: A `Tensor` of shape [time, batch_size, depth]. initial_state: A `Tensor` of shape [batch_size, depth]. - sequence_length: An `int32` `Tensor` of shape [batch_size]. parallel_iterations: Positive Python int. swap_memory: A Python boolean + sequence_length: (optional) An `int32` `Tensor` of shape [batch_size]. 
Returns: Tuple (final_outputs, final_state). @@ -502,8 +510,9 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length, # Prepare dynamic conditional copying of state & output zero_output = array_ops.zeros( array_ops.pack([batch_size, cell.output_size]), inputs.dtype) - min_sequence_length = math_ops.reduce_min(sequence_length) - max_sequence_length = math_ops.reduce_max(sequence_length) + if sequence_length is not None: + min_sequence_length = math_ops.reduce_min(sequence_length) + max_sequence_length = math_ops.reduce_max(sequence_length) time = array_ops.constant(0, dtype=dtypes.int32, name="time") @@ -536,9 +545,14 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length, # Restore some shape information input_t.set_shape([const_batch_size, const_depth]) - (output, new_state) = _rnn_step( - time, sequence_length, min_sequence_length, max_sequence_length, - zero_output, state, lambda: cell(input_t, state)) + call_cell = lambda: cell(input_t, state) + + if sequence_length is not None: + (output, new_state) = _rnn_step( + time, sequence_length, min_sequence_length, max_sequence_length, + zero_output, state, call_cell) + else: + (output, new_state) = call_cell() output_ta_t = output_ta_t.write(time, output) diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 65ea4d2e17..766bdf7dd3 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -28,5 +28,6 @@ limitations under the License. 
%include "tensorflow/python/client/events_writer.i" %include "tensorflow/python/client/tf_session.i" +%include "tensorflow/python/client/server_lib.i" %include "tensorflow/python/framework/python_op_gen.i" diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py index 6bc36429d9..661bae7bc1 100644 --- a/tensorflow/python/training/coordinator.py +++ b/tensorflow/python/training/coordinator.py @@ -335,7 +335,6 @@ class LooperThread(threading.Thread): looper.start() return looper - # pylint: disable=broad-except def run(self): with self._coord.stop_on_exception(): self.start_loop() @@ -349,12 +348,16 @@ class LooperThread(threading.Thread): while not self._coord.wait_for_stop(next_timer_time - time.time()): next_timer_time += self._timer_interval_secs self.run_loop() - # pylint: enable=broad-except + self.stop_loop() def start_loop(self): """Called when the thread starts.""" pass + def stop_loop(self): + """Called when the thread stops.""" + pass + def run_loop(self): """Called at 'timer_interval_secs' boundaries.""" if self._target: diff --git a/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts b/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts index 8f858becb2..ede6a1f5a3 100644 --- a/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts +++ b/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - -/// <reference path="../categorizer.ts" /> var assert = chai.assert; module Categorizer { diff --git a/tensorflow/tensorboard/components/tf-categorizer/test/index.html b/tensorflow/tensorboard/components/tf-categorizer/test/index.html new file mode 100644 index 0000000000..fd4a097708 --- /dev/null +++ b/tensorflow/tensorboard/components/tf-categorizer/test/index.html @@ -0,0 +1,13 @@ +<!doctype html> +<html> +<head> + <meta charset="utf-8"> + <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script> + <script src="../../web-component-tester/browser.js"></script> + <link rel="import" href="../../tf-imports/d3.html"> +</head> +<body> + <script src="../categorizer.js"></script> + <script src="categorizerTest.js"></script> +</body> +</html> diff --git a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts index 7148fd3fce..00c593a049 100644 --- a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts +++ b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - -/// <reference path="../plottable/plottable.d.ts" /> - module TF { export module Urls { export type RunTagUrlFn = (tag: string, run: string) => string; @@ -69,7 +66,21 @@ module TF { }; }; - export function demoRouter(dataDir: string): Router { + export function demoRouter(dataDir: string, + oldVersion = false): Router { + if (oldVersion) { + return { + runs: () => dataDir + "runs.json", + graph: (run) => dataDir + run + "-graph.pbtxt", + scalars: (tag, run) => { + return dataDir + run.split("_")[0] + ".json"; + }, + histograms: () => null, + compressedHistograms: () => null, + images: () => null, + individualImage: () => null + }; + } /* Retrieves static .json data generated by demo_from_server.py */ function demoRoute(route) { return function(tag, run) { diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts b/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts index 489a2138f0..5407800710 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts +++ b/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - -/// <reference path="../plottable/plottable.d.ts" /> - module TF { /* The DataCoordinator generates TF.Datasets for each run/tag combination, diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts b/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts index 8ced6ad0e2..3677a300d1 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts +++ b/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="../plottable/plottable.d.ts" /> - module TF { /* An extension of Plottable.Dataset that knows how to load data from a backend. */ diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts index 05fcf6b3e9..d799c190cf 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts +++ b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - -/// <reference path="../plottable/plottable.d.ts" /> - module TF { type TFDatum = [number, number, number]; type tooltipMap = {[run: string]: string}; diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts index ed89706b45..b2f6d21598 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="common.ts" /> module tf.graph { /** Delimiter used in node names to denote namespaces. */ diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts index 98f34bdd3f..af5c1e97b6 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts @@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="graph.ts" /> -/// <reference path="template.ts" /> - /** * Package for the Graph Hierarchy for TensorFlow graph. 
*/ diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts index 0e7b1d17d5..0d9e5b53bf 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts @@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="graph.ts" /> -/// <reference path="render.ts" /> - module tf.graph.layout { /** Set of parameters that define the look and feel of the graph. */ diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts index f88da0dd33..6d1aa875ee 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="common.ts" /> module tf.graph.parser { /** diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts index b0ee19a25e..fa0ee99d19 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts @@ -12,14 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - -/// <reference path="graph.ts" /> -/// <reference path="hierarchy.ts" /> - /** * Package for the Render Hierarchy for TensorFlow graph. */ - module tf.graph.render { export type Point = {x: number, y: number}; diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts index a50d31b5b9..b48d62c346 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts @@ -12,13 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="../graph.ts" /> -/// <reference path="../render.ts" /> -/// <reference path="scene.ts" /> -/// <reference path="edge.ts" /> -/// <reference path="contextmenu.ts" /> - module tf.graph.scene.annotation { /** diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts index d0f1e8fad6..2938aa3f1d 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts @@ -12,11 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="../graph.ts" /> -/// <reference path="../render.ts" /> -/// <reference path="scene.ts" /> - module tf.graph.scene.edge { /** Delimiter between dimensions when showing sizes of tensors. 
*/ diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts index 72464c69c4..bd8917929f 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="../common.ts" /> - module tf.scene { /** Show minimap when the viewpoint area is less than X% of the whole area. */ diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts index cef46578b5..f2e73976ff 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts @@ -12,12 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="../graph.ts" /> -/// <reference path="scene.ts" /> -/// <reference path="annotation.ts" /> -/// <reference path="contextmenu.ts" /> - module tf.graph.scene.node { /** diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts index 685ad646f7..b6eb3f7d81 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts @@ -12,12 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="../graph.ts" /> -/// <reference path="edge.ts" /> -/// <reference path="node.ts" /> -/// <reference path="../layout.ts" /> - module tf.graph.scene { /** Enums element class of objects in the scene */ diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts index 0423e1c863..93d1540939 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts @@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -/// <reference path="graph.ts" /> -/// <reference path="hierarchy.ts" /> - module tf.graph.template { /** diff --git a/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html b/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html index e97a1815c2..829769c3d0 100644 --- a/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html +++ b/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html @@ -1,40 +1,11 @@ <!DOCTYPE html> <html> - <head> - <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script> - <link rel="import" href="../tf-tensorboard.html"> +<head> + <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script> + <link rel="import" href="../tf-tensorboard-demo.html"> <link rel="stylesheet" type="text/css" href="../../../lib/css/global.css"> - <title>TensorBoard Demo</title> - </head> - <body> - <base href="/"> - <dom-module id="x-demo"> - <template> - <tf-tensorboard - id="demo" - router="[[demoRouter]]"> - 
</tf-tensorboard> - </template> - <script> - var dataDir = "components/tf-tensorboard/demo/data/"; - var demoRouter = { - runs: function() { return dataDir + "runs.json";}, - graph: function(run) {return dataDir + run + "-graph.pbtxt";}, - scalars: function(tag, run) { - return dataDir + run.split("_")[0] + ".json"; - }, - }; - Polymer({ - is: "x-demo", - properties: { - demoRouter: { - type: Object, - value: demoRouter, - }, - }, - }); - </script> - </dom-module> - <x-demo></x-demo> - </body> +</head> +<body> + <tf-tensorboard-demo old-version="true" data-dir="data/"></tf-tensorboard-demo> +</body> </html> diff --git a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html index 8fe248aff0..abed65ef66 100644 --- a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html +++ b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard-demo.html @@ -9,6 +9,7 @@ json data from a "dataDir" rather than connecting to a live backend. <tf-tensorboard id="tensorboard" router="[[_demoRouter]]" + no-hash="[[noHash]]" ></tf-tensorboard> <style> :host { @@ -23,15 +24,27 @@ json data from a "dataDir" rather than connecting to a live backend. properties: { _demoRouter: { type: Object, - computed: "_makeDemoRouter(dataDir)", + computed: "_makeDemoRouter(dataDir, oldVersion)", }, dataDir: { type: String, value: "data", }, + // To use the old version of the router which can serve the + // demo/data folder that is checked into the repository. + oldVersion: { + type: Boolean, + value: false + }, + // If true, tab switching in TensorBoard will not update + // location hash. Hash update interferes with selenium tests. 
+ noHash: { + type: Boolean, + value: false + } }, - _makeDemoRouter: function(dataDir) { - return TF.Urls.demoRouter(dataDir); + _makeDemoRouter: function(dataDir, oldVersion) { + return TF.Urls.demoRouter(dataDir, oldVersion); }, }); </script> diff --git a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html index bfcbb7ae5f..1c5ff47564 100644 --- a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html +++ b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html @@ -20,11 +20,11 @@ allows the user to toggle between various dashboards. <paper-toolbar id="toolbar"> <div id="toolbar-content"> <div class="toolbar-title">TensorBoard</div> - <paper-tabs selected="0" noink class="tabs" id="tabs"> - <paper-tab data-mode="events" on-click="changeMode">Events</paper-tab> - <paper-tab data-mode="images" on-click="changeMode">Images</paper-tab> - <paper-tab data-mode="graphs" on-click="changeMode">Graph</paper-tab> - <paper-tab data-mode="histograms" on-click="changeMode">Histograms</paper-tab> + <paper-tabs selected="{{modeIndex}}" noink class="tabs" id="tabs"> + <paper-tab data-mode="events">Events</paper-tab> + <paper-tab data-mode="images">Images</paper-tab> + <paper-tab data-mode="graphs">Graph</paper-tab> + <paper-tab data-mode="histograms">Histograms</paper-tab> </paper-tabs> </div> </paper-toolbar> @@ -111,14 +111,24 @@ allows the user to toggle between various dashboards. type: Object, value: TF.Urls.productionRouter(), }, + // Which tab is selected (events, graph, images etc). mode: { type: String, - value: "events", + computed: '_getModeFromIndex(modeIndex)' }, + // If true, tab switching in TensorBoard will not update + // location hash. Hash update interferes with selenium tests. 
+ noHash: { + type: Boolean, + value: false + } }, - changeMode: function(ev) { - var mode = ev.target.parentElement.getAttribute('data-mode'); - this._changeMode(mode, true); + _getModeFromIndex: function(modeIndex) { + var mode = this.tabs[modeIndex]; + if (!this.noHash) { + window.location.hash = mode; + } + return mode; }, eventDashboard: function(mode) { return mode === "events"; @@ -132,36 +142,26 @@ allows the user to toggle between various dashboards. histogramDashboard: function(mode) { return mode === "histograms"; }, - loadPreviousMode: function() { - this._changeMode(this._getMode(), false); - }, ready: function() { - this._changeMode(this._getMode(), true); - - var self = this; - window.addEventListener('hashchange', function(){ - self.loadPreviousMode(); + this.tabs = [].slice.call(this.querySelectorAll('paper-tab')).map(function(a) { + return a.dataset.mode; }); + this._getModeFromHash(); + window.addEventListener('hashchange', function() { + this._getModeFromHash(); + }.bind(this)); }, - _changeMode: function(mode, isNewState) { - this.mode = mode; - - // Change the selected tab - this.$.tabs.selected = this._tabs().indexOf(mode); - - if (isNewState){ - window.location.hash = mode; - } - }, - _getMode: function() { + _getModeFromHash: function() { // Return the mode as it is stored in the hash. - // If no mode can be found, default to the first tab. - var hash = window.location.hash; - return hash.length > 0 ? hash.slice(1, hash.length) : this._tabs()[0]; - }, - _tabs: function() { - var elts = Array.prototype.slice.call(this.querySelectorAll('paper-tab')); - return elts.map(function(elt){ return elt.getAttribute('data-mode')}); + var tabName = window.location.hash.trim().slice(1); + var modeIndex = this.tabs.indexOf(tabName); + if (modeIndex == -1 && this.modeIndex == null) { + // Selecting the first tab as default. 
+ this.set('modeIndex', 0); + } + if (modeIndex != -1 && modeIndex != this.modeIndex) { + this.set('modeIndex', modeIndex); + } }, }); </script> diff --git a/tensorflow/tensorboard/components/tf-test/index.html b/tensorflow/tensorboard/components/tf-test/index.html deleted file mode 100644 index d551750e3c..0000000000 --- a/tensorflow/tensorboard/components/tf-test/index.html +++ /dev/null @@ -1,16 +0,0 @@ -<!doctype html> -<html> -<head> - <meta charset="utf-8"> - <script src="../web-component-tester/browser.js"></script> -</head> -<body> -<script> -// Run the tests for each main component in tensorboard. -WCT.loadSuites([ - '../tf-graph-common/test/index.html', - '../tf-graph-loader/test/index.html', -]); -</script> -</body> -</html> diff --git a/tensorflow/tensorboard/gulpfile.js b/tensorflow/tensorboard/gulpfile.js index dcc79f3008..6eeb24ddbe 100644 --- a/tensorflow/tensorboard/gulpfile.js +++ b/tensorflow/tensorboard/gulpfile.js @@ -98,8 +98,7 @@ gulp.task('compile.all', ['typings'], function() { }); gulp.task('test', ['tslint-strict', 'compile.all'], function(done) { - tester({suites: ['components/tf-test/'], - plugins: {local: {}, sauce: false}}, function(error) { + tester({}, function(error) { if (error) { // Pretty error for gulp. error = new Error(error.message || error); diff --git a/tensorflow/tensorboard/lib/js/backend/test/index.html b/tensorflow/tensorboard/lib/js/backend/test/index.html index 2305cf9426..7965ce6d0b 100644 --- a/tensorflow/tensorboard/lib/js/backend/test/index.html +++ b/tensorflow/tensorboard/lib/js/backend/test/index.html @@ -14,13 +14,9 @@ limitations under the License. 
=============================================================================--> <!doctype html> <html> -<!-- This test file has import paths that are suitable for gulp test and - direct loading in the browser --> <head> <meta charset="utf-8"> - <script src="../../../../../components/webcomponentsjs/webcomponents-lite.min.js"></script> - <script src="../../../../components/web-component-tester/browser.js"></script> - + <script src="../../web-component-tester/browser.js"></script> </head> <body> <script src="../../requestManager/requestManager.js"></script> diff --git a/tensorflow/tensorboard/lib/js/nanite/test/index.html b/tensorflow/tensorboard/lib/js/nanite/test/index.html index 0ac18a1bf2..2a886afe62 100644 --- a/tensorflow/tensorboard/lib/js/nanite/test/index.html +++ b/tensorflow/tensorboard/lib/js/nanite/test/index.html @@ -1,13 +1,10 @@ <!doctype html> <html> -<!-- This test file has import paths that are suitable for gulp test and - direct loading in the browser --> <head> <meta charset="utf-8"> - <script src="../../../../../components/webcomponentsjs/webcomponents-lite.min.js"></script> - <script src="../../../../components/web-component-tester/browser.js"></script> - <link rel="import" href="../../../../../components/polymer/polymer.html"> - + <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script> + <script src="../../web-component-tester/browser.js"></script> + <link rel="import" href="../../polymer/polymer.html"> </head> <body> <script src="../nanite.js"></script> diff --git a/tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts b/tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts index ecc792944e..ba9dce0f57 100644 --- a/tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts +++ b/tensorflow/tensorboard/lib/js/nanite/test/naniteTest.ts @@ -14,11 +14,9 @@ limitations under the License. 
==============================================================================*/ var assert = chai.assert; declare function fixture(id: string): void; -declare module HTMLImports { - export function whenReady(f: Function): void; -} + module TF.Nanite { - HTMLImports.whenReady(function() { + window.HTMLImports.whenReady(function() { Polymer({ is: "test-element", properties: { diff --git a/tensorflow/tensorboard/lib/js/node-radar/test/index.html b/tensorflow/tensorboard/lib/js/node-radar/test/index.html index afb21ba15f..83c3018ed2 100644 --- a/tensorflow/tensorboard/lib/js/node-radar/test/index.html +++ b/tensorflow/tensorboard/lib/js/node-radar/test/index.html @@ -1,12 +1,8 @@ - <!doctype html> <html> - <!-- This test file has import paths that are suitable for gulp test and - direct loading in the browser --> <head> <meta charset="utf-8"> - <script src="../../../../../components/webcomponentsjs/webcomponents-lite.min.js"></script> - <script src="../../../../components/web-component-tester/browser.js"></script> + <script src="../../web-component-tester/browser.js"></script> </head> <body> <script src="../nodeRadar.js"></script> diff --git a/tensorflow/tensorboard/lib/js/requestManager/test/index.html b/tensorflow/tensorboard/lib/js/requestManager/test/index.html index b9712e8daf..53487f1f58 100644 --- a/tensorflow/tensorboard/lib/js/requestManager/test/index.html +++ b/tensorflow/tensorboard/lib/js/requestManager/test/index.html @@ -2,8 +2,7 @@ <html> <head> <meta charset="utf-8"> - <script src="../../../../../components/webcomponentsjs/webcomponents-lite.min.js"></script> - <script src="../../../../components/web-component-tester/browser.js"></script> + <script src="../../web-component-tester/browser.js"></script> </head> <body> <script src="../requestManager.js"></script> diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json index 1902bc4756..25bb35df67 100644 --- a/tensorflow/tensorboard/package.json +++ 
b/tensorflow/tensorboard/package.json @@ -25,7 +25,7 @@ "tslint": "^3.2.1", "typescript": "1.8.0", "vulcanize": "^1.14.0", - "web-component-tester": "~3.4.2", + "web-component-tester": "4.2.2", "gulp-header": "~1.7.1", "gulp-rename": "~1.2.2", "gulp-typings": "~1.1.0", diff --git a/tensorflow/tensorboard/wct.conf.json b/tensorflow/tensorboard/wct.conf.json new file mode 100644 index 0000000000..0a5c6c20b6 --- /dev/null +++ b/tensorflow/tensorboard/wct.conf.json @@ -0,0 +1,12 @@ +{ + "suites": [ + "components/tf-*/test", + "lib/js/*/test" + ], + "plugins": ["local"], + "webserver": { + "pathMappings": [ + {"/components/<basename>/lib/js": "components"} + ] + } +}
\ No newline at end of file diff --git a/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb b/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb index 9bb889e41a..8f8bedbdfe 100644 --- a/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb +++ b/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb @@ -1273,7 +1273,7 @@ "batch_labels = train_labels[:BATCH_SIZE]\n", "\n", "# This dictionary maps the batch data (as a numpy array) to the\n", - "# node in the graph is should be fed to.\n", + "# node in the graph it should be fed to.\n", "feed_dict = {train_data_node: batch_data,\n", " train_labels_node: batch_labels}\n", "\n", @@ -1680,7 +1680,7 @@ " batch_data = train_data[offset:(offset + BATCH_SIZE), :, :, :]\n", " batch_labels = train_labels[offset:(offset + BATCH_SIZE)]\n", " # This dictionary maps the batch data (as a numpy array) to the\n", - " # node in the graph is should be fed to.\n", + " # node in the graph it should be fed to.\n", " feed_dict = {train_data_node: batch_data,\n", " train_labels_node: batch_labels}\n", " # Run the graph and fetch some of the nodes.\n", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 1a8a026b7e..6debeabd97 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -10,8 +10,8 @@ def tf_workspace(path_prefix = ""): native.new_http_archive( name = "eigen_archive", - url = "https://bitbucket.org/eigen/eigen/get/f1ce2528ee99.tar.gz", - sha256 = "2c4ce322d13a613bbc53de3381760cf56f1c9b03c409233b764a6434ee1db909", + url = "https://bitbucket.org/eigen/eigen/get/88444e025a5c.tar.gz", + sha256 = "42e6f6de56b3ff010531a2bbf3e2db1db46be30d3965efb1eaa5634c5db013dd", build_file = path_prefix + "eigen.BUILD", ) diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky index af815350c8..95a503d611 100644 --- a/third_party/eigen3/Eigen/Cholesky +++ b/third_party/eigen3/Eigen/Cholesky @@ -1 +1 @@ -#include "eigen-eigen-f1ce2528ee99/Eigen/Cholesky" 
+#include "eigen-eigen-88444e025a5c/Eigen/Cholesky" diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core index 1625edf8f7..b4a10f6ed1 100644 --- a/third_party/eigen3/Eigen/Core +++ b/third_party/eigen3/Eigen/Core @@ -1 +1 @@ -#include "eigen-eigen-f1ce2528ee99/Eigen/Core" +#include "eigen-eigen-88444e025a5c/Eigen/Core" diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues index f5e92ae98a..56657aa837 100644 --- a/third_party/eigen3/Eigen/Eigenvalues +++ b/third_party/eigen3/Eigen/Eigenvalues @@ -1 +1 @@ -#include "eigen-eigen-f1ce2528ee99/Eigen/Eigenvalues" +#include "eigen-eigen-88444e025a5c/Eigen/Eigenvalues" diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU index 77f592a412..3c491eeef9 100644 --- a/third_party/eigen3/Eigen/LU +++ b/third_party/eigen3/Eigen/LU @@ -1 +1 @@ -#include "eigen-eigen-f1ce2528ee99/Eigen/LU" +#include "eigen-eigen-88444e025a5c/Eigen/LU" diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR index 2f1eeb9a6e..5a97880470 100644 --- a/third_party/eigen3/Eigen/QR +++ b/third_party/eigen3/Eigen/QR @@ -1 +1 @@ -#include "eigen-eigen-f1ce2528ee99/Eigen/QR" +#include "eigen-eigen-88444e025a5c/Eigen/QR" diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor index b87d22f207..20150d0594 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor +++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor @@ -1 +1 @@ -#include "eigen-eigen-f1ce2528ee99/unsupported/Eigen/CXX11/Tensor" +#include "eigen-eigen-88444e025a5c/unsupported/Eigen/CXX11/Tensor" diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template index e90ec790fd..d2b1b0b25a 100644 --- a/tools/bazel.rc.template +++ b/tools/bazel.rc.template @@ -3,8 +3,6 @@ build:cuda --define=using_cuda=true build --force_python=py$PYTHON_MAJOR_VERSION build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY -build --define=use_fast_cpp_protos=true 
-build --define=allow_oversize_protos=true build --spawn_strategy=standalone test --spawn_strategy=standalone |