diff options
143 files changed, 4189 insertions, 573 deletions
diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 2f3df7cda9..52faed9297 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -15,9 +15,10 @@ If you open a GitHub issue, here is our policy: ### System information - **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**: - **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**: +- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**: - **TensorFlow installed from (source or binary)**: - **TensorFlow version (use command below)**: -- **Python version**: +- **Python version**: - **Bazel version (if compiling from source)**: - **GCC/Compiler version (if compiling from source)**: - **CUDA/cuDNN version**: @@ -18,7 +18,7 @@ closure_repositories() # files, in case the parsing of those build files depends on the bazel # version we require here. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.10.0") +check_bazel_version_at_least("0.15.0") load("//tensorflow:workspace.bzl", "tf_workspace") diff --git a/configure.py b/configure.py index c482628ec8..25729adf36 100644 --- a/configure.py +++ b/configure.py @@ -1429,7 +1429,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.10.0') + check_bazel_version('0.15.0') reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index c2245b8eae..9174a67cc6 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -304,11 +304,13 @@ cc_library( name = "compilation_passes", srcs = [ "build_xla_launch_ops_pass.cc", + "deadness_analysis.cc", "encapsulate_subgraphs_pass.cc", "mark_for_compilation_pass.cc", ], hdrs = [ "build_xla_launch_ops_pass.h", + "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "mark_for_compilation_pass.h", ], @@ -325,6 +327,7 @@ cc_library( "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -377,6 +380,7 @@ tf_cc_test( name = "compilation_passes_test", size = "small", srcs = [ + "deadness_analysis_test.cc", "encapsulate_subgraphs_pass_test.cc", "mark_for_compilation_pass_test.cc", ], @@ -387,6 +391,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -458,6 +463,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":common", + ":compilation_passes", ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc new file mode 100644 index 0000000000..b2d119029a --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -0,0 +1,546 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/deadness_analysis.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" + +// ALGORITHM OVERVIEW +// +// We map every output produced by each node in the TensorFlow graph (including +// control dependence) into an instance of the Predicate class. Instances of +// Predicate denote logical formulas and mapping a node `n` to a predicate +// `pred` implies that `n` is executed whenver `pred` is true. Then we can +// deduce mismatching liveness in the inputs to node by comparing the predicate +// those inputs are mapped to. +// +// Loops are handled pessimistically -- we map Merge nodes with backedges to +// uninterpreted symbols (the same kind we use to represent Switch and _Recv). +// Predicate equality has to hold over all possible assignments to these +// uninterpreted symbols. + +namespace tensorflow { + +namespace { + +// Represents a logical predicate, used as described in the algorithm overview +// above. +class Predicate { + public: + enum class Kind { kAnd, kOr, kNot, kSymbol }; + + virtual string ToString() const = 0; + virtual bool operator==(const Predicate& other) const = 0; + virtual bool operator!=(const Predicate& other) const { + return !(*this == other); + } + int64 hash() const { return hash_; } + + virtual Kind kind() const = 0; + virtual ~Predicate() {} + + protected: + explicit Predicate(int64 hash) : hash_(hash) {} + + private: + const int64 hash_; +}; + +int64 HashPredicateSequence(Predicate::Kind kind, + gtl::ArraySlice<Predicate*> preds) { + int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind); + for (Predicate* pred : preds) { + hash = Hash64Combine(hash, pred->hash()); + } + return hash; +} + +bool PredicateSequenceEqual(gtl::ArraySlice<Predicate*> lhs, + gtl::ArraySlice<Predicate*> rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (int64 i = 0; i < lhs.size(); i++) { + if (*lhs[i] != *rhs[i]) { + return false; + } + } + return true; +} + +// Represents a logical conjunction of a set of predicates. +class AndPredicate : public Predicate { + public: + explicit AndPredicate(std::vector<Predicate*> operands) + : Predicate(HashPredicateSequence(Kind::kAnd, operands)), + operands_(std::move(operands)) {} + + string ToString() const override { + if (operands().empty()) { + return "#true"; + } + + std::vector<string> operands_str; + std::transform(operands().begin(), operands().end(), + std::back_inserter(operands_str), + [](Predicate* pred) { return pred->ToString(); }); + + return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); + } + + bool operator==(const Predicate& other) const override { + return other.kind() == Kind::kAnd && + PredicateSequenceEqual( + dynamic_cast<const AndPredicate&>(other).operands(), operands()); + } + + Kind kind() const override { return Kind::kAnd; } + + const tensorflow::gtl::ArraySlice<Predicate*> operands() const { + return operands_; + } + + private: + std::vector<Predicate*> operands_; +}; + +// Represents a logical disjunction of a set of predicates. +class OrPredicate : public Predicate { + public: + explicit OrPredicate(std::vector<Predicate*> operands) + : Predicate(HashPredicateSequence(Kind::kOr, operands)), + operands_(std::move(operands)) {} + + string ToString() const override { + if (operands().empty()) { + return "#false"; + } + + std::vector<string> operands_str; + std::transform(operands().begin(), operands().end(), + std::back_inserter(operands_str), + [](Predicate* pred) { return pred->ToString(); }); + + return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); + } + + bool operator==(const Predicate& other) const override { + return other.kind() == Kind::kOr && + PredicateSequenceEqual( + dynamic_cast<const OrPredicate&>(other).operands(), operands()); + } + + Kind kind() const override { return Kind::kOr; } + const tensorflow::gtl::ArraySlice<Predicate*> operands() const { + return operands_; + } + + private: + std::vector<Predicate*> operands_; +}; + +// Represents a logical negation of a set of predicates. +class NotPredicate : public Predicate { + public: + explicit NotPredicate(Predicate* operand) + : Predicate(HashPredicateSequence(Kind::kNot, {operand})), + operand_(operand) {} + + string ToString() const override { + return strings::StrCat("~", operand()->ToString()); + } + + bool operator==(const Predicate& other) const override { + return other.kind() == Kind::kNot && + *dynamic_cast<const NotPredicate&>(other).operand() == *operand(); + } + + Kind kind() const override { return Kind::kNot; } + Predicate* operand() const { return operand_; } + + private: + Predicate* operand_; +}; + +// Represents an uninterpreted symbol in a logical predicate. +// +// Two predicates are equivalent iff they are equivalent for all assignments to +// the symbols contained in them. +class SymbolPredicate : public Predicate { + public: + explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) + : Predicate(Hash(tensor_id, must_be_true)), + tensor_id_(std::move(tensor_id)), + must_be_true_(must_be_true) {} + + string ToString() const override { return tensor_id_.ToString(); } + bool operator==(const Predicate& other) const override { + return other.kind() == Kind::kSymbol && + must_be_true() == + dynamic_cast<const SymbolPredicate&>(other).must_be_true() && + dynamic_cast<const SymbolPredicate&>(other).tensor_id() == + tensor_id(); + } + + Kind kind() const override { return Kind::kSymbol; } + + // If `must_be_true()` is true this SymbolPredicate represents the proposition + // "tensor_id() is live and evaluates to true". + // + // If `must_be_true()` is false then this SymbolPredicate represents the + // proposition "tensor_id() is live (and may evalutate to any value)" + TensorId tensor_id() const { return tensor_id_; } + bool must_be_true() const { return must_be_true_; } + + private: + TensorId tensor_id_; + bool must_be_true_; + + static int64 Hash(const TensorId tensor_id, bool must_be_true) { + return Hash64Combine( + ::tensorflow::hash<bool>()(must_be_true), + Hash64Combine(::tensorflow::hash<Predicate::Kind>()(Kind::kSymbol), + TensorId::Hasher{}(tensor_id))); + } +}; + +// Creates and owns Predicate instances. Simplifies predicates as it creates +// them. +class PredicateFactory { + public: + Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) { + return MakeAndOrImpl(operands, /*is_and=*/true); + } + Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) { + return MakeAndOrImpl(operands, /*is_and=*/false); + } + + Predicate* MakeNotPredicate(Predicate* pred) { + return Make<NotPredicate>(pred); + } + + Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { + return Make<SymbolPredicate>(tensor_id, must_be_true); + } + + Predicate* MakeTrue() { return MakeAndPredicate({}); } + Predicate* MakeFalse() { return MakeOrPredicate({}); } + + private: + template <typename PredicateT, typename... Args> + Predicate* Make(Args... args) { + std::unique_ptr<PredicateT> pred( + new PredicateT(std::forward<Args>(args)...)); + predicate_storage_.emplace_back(std::move(pred)); + return predicate_storage_.back().get(); + } + + Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and); + + struct PredicatePtrHash { + size_t operator()(const Predicate* pred) const { return pred->hash(); } + }; + + struct PredicatePtrEq { + size_t operator()(const Predicate* a, const Predicate* b) const { + return *a == *b; + } + }; + + using PredicateSet = + gtl::FlatSet<Predicate*, PredicatePtrHash, PredicatePtrEq>; + + std::vector<std::unique_ptr<Predicate>> predicate_storage_; +}; + +// Common code to create AndPredicate or OrPredicate instances. +Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, + bool is_and) { + Predicate::Kind pred_kind = + is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; + PredicateSet simplified_ops_set; + std::vector<Predicate*> simplified_ops; + for (Predicate* op : operands) { + // Simplify A&A => A and A|A => A. + if (!simplified_ops_set.insert(op).second) { + continue; + } + + if (op->kind() == pred_kind) { + // "Inline" the operands of an inner And/Or into the parent And/Or. + gtl::ArraySlice<Predicate*> operands = + is_and ? dynamic_cast<AndPredicate*>(op)->operands() + : dynamic_cast<OrPredicate*>(op)->operands(); + for (Predicate* subop : operands) { + if (simplified_ops_set.insert(subop).second) { + simplified_ops.push_back(subop); + } + } + } else { + simplified_ops.push_back(op); + } + } + + if (simplified_ops.size() == 1) { + return simplified_ops[0]; + } + + // Simplify "A&~A=>False" and "A|~A=>True". + PredicateSet negated_ops; + for (Predicate* op : simplified_ops) { + if (op->kind() == Predicate::Kind::kNot) { + negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand()); + } + } + + for (Predicate* op : simplified_ops) { + if (negated_ops.count(op)) { + return is_and ? MakeFalse() : MakeTrue(); + } + } + + std::stable_sort( + simplified_ops.begin(), simplified_ops.end(), + [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + + return is_and ? Make<AndPredicate>(std::move(simplified_ops)) + : Make<OrPredicate>(std::move(simplified_ops)); +} + +class DeadnessAnalysisImpl : public DeadnessAnalysis { + public: + explicit DeadnessAnalysisImpl(const Graph* graph) + : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} + + Status Populate(); + bool HasInputsWithMismatchingDeadness(const Node& node) override; + void Print() const override; + + private: + enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; + + std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind); + void SetPred(Node* n, int output_idx, Predicate* pred) { + CHECK( + predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second); + } + void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred) { + for (int output_idx : output_idxs) { + SetPred(n, output_idx, pred); + } + } + + Status HandleSwitch(Node* n); + Status HandleMerge(Node* n); + Status HandleRecv(Node* n); + Status HandleGeneric(Node* n); + + const Graph& graph_; + gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_; + PredicateFactory predicate_factory_; + bool vlog_; +}; + +TensorId InputEdgeToTensorId(const Edge* e) { + return TensorId(e->src()->name(), e->src_output()); +} + +std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds( + Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) { + std::vector<Predicate*> incoming_preds; + for (const Edge* in_edge : n->in_edges()) { + bool should_process = + edge_kind == EdgeKind::kDataAndControl || + (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) || + (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly); + + if (should_process) { + auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); + CHECK(it != predicate_map_.end()); + incoming_preds.push_back(it->second); + } + } + return incoming_preds; +} + +Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { + std::vector<Predicate*> input_preds = + GetIncomingPreds(n, EdgeKind::kDataAndControl); + const Edge* pred_edge; + TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); + Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( + TensorId(pred_edge->src()->name(), pred_edge->src_output()), + /*must_be_true=*/true); + Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); + + // Output 0 is alive iff all inputs are alive and the condition is false. + input_preds.push_back(false_switch); + SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds)); + input_preds.pop_back(); + + // Output 1 is alive iff all inputs are alive and the condition is true. + input_preds.push_back(true_switch); + SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds)); + input_preds.pop_back(); + + // Control is alive iff any inputs are alive. + SetPred(n, Graph::kControlSlot, + predicate_factory_.MakeAndPredicate(input_preds)); + + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleMerge(Node* n) { + // Merge ignores deadness of its control inputs. A merge that isn't the + // target of a backedge has is alive iff any of its data inputs are. We treat + // the liveness of a merge that is the target of a backedge symbolically. + + bool has_backedge = std::any_of( + n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) { + return !e->IsControlEdge() && e->src()->IsNextIteration(); + }); + + Predicate* input_data_pred = + has_backedge ? predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false) + : predicate_factory_.MakeOrPredicate( + GetIncomingPreds(n, EdgeKind::kDataOnly)); + + SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleRecv(Node* n) { + // In addition to being alive or dead based on the inputs, a _Recv can also + // acquire a dead signal from a _Send. + std::vector<Predicate*> input_preds = + GetIncomingPreds(n, EdgeKind::kDataAndControl); + input_preds.push_back(predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false)); + SetPred(n, {0, Graph::kControlSlot}, + predicate_factory_.MakeAndPredicate(input_preds)); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleGeneric(Node* n) { + // Generally nodes are alive iff all their inputs are alive. + Predicate* pred = predicate_factory_.MakeAndPredicate( + GetIncomingPreds(n, EdgeKind::kDataAndControl)); + for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { + SetPred(n, output_idx, pred); + } + SetPred(n, Graph::kControlSlot, pred); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::Populate() { + std::vector<Node*> rpo; + GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{}, + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + // This an abstract interpretation over the deadness propagation semantics of + // the graph executor. + for (Node* n : rpo) { + if (n->IsSwitch()) { + TF_RETURN_IF_ERROR(HandleSwitch(n)); + } else if (n->IsMerge()) { + TF_RETURN_IF_ERROR(HandleMerge(n)); + } else if (n->IsControlTrigger()) { + SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue()); + } else if (n->IsRecv() || n->IsHostRecv()) { + TF_RETURN_IF_ERROR(HandleRecv(n)); + } else { + TF_RETURN_IF_ERROR(HandleGeneric(n)); + } + } + + return Status::OK(); +} + +bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) { + CHECK(!node.IsMerge()); + + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")"; + } + + Predicate* pred = nullptr; + for (const Edge* edge : node.in_edges()) { + auto it = predicate_map_.find(InputEdgeToTensorId(edge)); + CHECK(it != predicate_map_.end()); + if (vlog_) { + VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": " + << it->second->ToString(); + } + + // Today we just compare the predicates for equality (with some + // canonicalization/simplification happening before) but we could be more + // sophisticated here if need be. + if (pred != nullptr && *pred != *it->second) { + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() + << ") -> true"; + } + return true; + } + pred = it->second; + } + + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() + << ") -> false"; + } + + return false; +} + +void DeadnessAnalysisImpl::Print() const { + std::vector<TensorId> tensor_ids; + for (const auto& kv_pair : predicate_map_) { + tensor_ids.push_back(kv_pair.first); + } + + std::sort(tensor_ids.begin(), tensor_ids.end()); + + for (TensorId tensor_id : tensor_ids) { + auto it = predicate_map_.find(tensor_id); + CHECK(it != predicate_map_.end()) << tensor_id.ToString(); + VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString(); + } +} + +} // namespace + +DeadnessAnalysis::~DeadnessAnalysis() {} + +/*static*/ Status DeadnessAnalysis::Run( + const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) { + std::unique_ptr<DeadnessAnalysisImpl> analysis( + new DeadnessAnalysisImpl(&graph)); + TF_RETURN_IF_ERROR(analysis->Populate()); + + if (VLOG_IS_ON(2)) { + analysis->Print(); + } + + *result = std::move(analysis); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h new file mode 100644 index 0000000000..6e7ab41161 --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// This analyzes a TensorFlow graph to identify nodes which may have partially +// dead inputs (i.e. these nodes may have some dead inputs and some alive +// inputs). +// +// For example, the ADD node in the following graph +// +// V0 PRED0 V1 PRED1 +// | | | | +// v v v v +// SWITCH SWITCH +// | | +// +---+ + ---+ +// | | +// v v +// ADD +// +// can have its inputs independently dead or alive based on the runtime values +// of PRED0 and PRED1. +// +// It is tempting to call this a liveness analysis but I avoided that because +// "liveness" already has other connotations. +class DeadnessAnalysis { + public: + // Returns true if `node` may have some live inputs and some dead inputs. + // + // This is a conservatively correct routine -- if it returns false then `node` + // is guaranteed to not have inputs with mismatching liveness, but not the + // converse. + // + // REQUIRES: node is not a Merge operation. + virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0; + + // Prints out the internal state of this instance. For debugging purposes + // only. + virtual void Print() const = 0; + virtual ~DeadnessAnalysis(); + + // Run the deadness analysis over `graph` and returns an error or a populated + // instance of DeadnessAnalysis in `result`. + static Status Run(const Graph& graph, + std::unique_ptr<DeadnessAnalysis>* result); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc new file mode 100644 index 0000000000..584385cab7 --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -0,0 +1,443 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/deadness_analysis.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Status AnalyzeDeadness(Graph* graph, + std::unique_ptr<DeadnessAnalysis>* result) { + FixupSourceAndSinkEdges(graph); + return DeadnessAnalysis::Run(*graph, result); +} + +ops::Switch CreateSwitch(const Scope& root, const string& prefix) { + Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT); + Output predicate = + ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL); + return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate); +} + +Output CreateInductionVariable(const Scope& root, const string& prefix, + const string& frame_name, int32 init) { + Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init); + Output enter_initial_value = ops::internal::Enter( + root.WithOpName(prefix + "/enter"), initial_value, frame_name); + + ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value}); + Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); + Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); + Output loop_cond_expr = + ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value); + Output loop_cond = + ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + Output iv_next = + ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by); + Output next_iteration = + ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next); + + root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1); + root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); + root.graph()->AddControlEdge(iv.output.node(), final_value.node()); + + return iv.output; +} + +TEST(DeadnessAnalysisTest, BasicPositive) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw = CreateSwitch(root, "0"); + Output add = + ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, BasicNegative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), a, b); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, AndIsCommutative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + + Output a0 = + ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false); + Output a1 = + ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false); + + Output b0 = + ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true); + Output b1 = + ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false); + + Output live0 = ops::Add(root.WithOpName("live0"), a0, a1); + Output live1 = ops::Add(root.WithOpName("live1"), b0, b1); + + Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0); + Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node())); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node())); +} + +TEST(DeadnessAnalysisTest, AndIsAssociative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + Output a0 = + ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false); + Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false); + + Output b0 = + ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false); + Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0); + + Output add = ops::Add(root.WithOpName("add"), a1, b1); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, OrIsCommutative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false}); + ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true}); + ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false}); + + Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output); + Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output); + + Output halfdead0 = + ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output); + Output halfdead1 = + ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node())); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node())); +} + +TEST(DeadnessAnalysisTest, OrIsAssociative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false}); + ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false}); + ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output}); + + Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, AndOfOr) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + ops::Switch sw_3 = CreateSwitch(root, "3"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false}); + + Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output); + Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output); + + Output add2 = ops::Add(root.WithOpName("add2"), add0, add1); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, OrOfAnd) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + ops::Switch sw_3 = CreateSwitch(root, "3"); + + Output add0 = + ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false); + Output add1 = + ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false); + + ops::Merge m0(root.WithOpName("m0"), {add0, add1}); + ops::Merge m1(root.WithOpName("m1"), {add0, add1}); + + Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { + // This demonstrates one of the weaknesses in the current approach -- since we + // only do some basic simplifications we can't see that "(A|B)&C" == + // "(A&C)|(B&C)". + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false); + + Output add1 = + ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false); + Output add2 = + ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false); + ops::Merge m1(root.WithOpName("m1"), {add1, add2}); + + Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, Ternary) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL); + Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT); + Output false_value = + ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT); + + ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value, + predicate); + + ops::Switch predicated_false(root.WithOpName("predicated_false"), true_value, + predicate); + ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true, + predicated_false.output_false}); + Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), merge.output, addend); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, Recv) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a", + "sender", 0, "receiver"); + Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b", + "sender", 0, "receiver"); + Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, HostRecv) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT, + "tensor_a", "sender", 0, "receiver"); + Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT, + "tensor_b", "sender", 0, "receiver"); + Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, Loop) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0); + Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0); + Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1); + Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1); + Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have + // noticed that. Today we are pessimistic here because we assign an + // uninterpreted symbol to merges with backedges. + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); +} + +TEST(DeadnessAnalysisTest, ControlInputs) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + Output const0 = ops::Const(root.WithOpName("const0"), 1); + Output const1 = ops::Const(root.WithOpName("const1"), 2); + + Output add = ops::Add(root.WithOpName("add"), const0, const1); + + root.graph()->AddControlEdge(id0.node(), const0.node()); + root.graph()->AddControlEdge(id1.node(), const1.node()); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, ControlTrigger) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0")); + ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1")); + + Output const0 = ops::Const(root.WithOpName("const0"), 1); + Output const1 = ops::Const(root.WithOpName("const1"), 2); + + Output add = ops::Add(root.WithOpName("add"), const0, const1); + + root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node()); + root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node()); + + root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node()); + root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node()); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, ControlInputsToMerge) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + Output constant = ops::Const(root.WithOpName("constant"), 5); + ops::Merge m0(root.WithOpName("m0"), {constant}); + ops::Merge m1(root.WithOpName("m0"), {constant}); + Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output); + + root.graph()->AddControlEdge(id0.node(), m0.output.node()); + root.graph()->AddControlEdge(id1.node(), m1.output.node()); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, RecvVsSwitch) { + // Demonstrates why we need the must_be_true bit on SymbolP. + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender", + 0, "receiver"); + Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL); + ops::Switch sw(root.WithOpName("switch"), value, recv); + Output logical_and = + ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node())); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 9c424b201e..fdd71c6a58 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -138,7 +138,7 @@ class Encapsulator { // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. - Status SplitIntoSubgraphs(); + Status SplitIntoSubgraphs(FunctionLibraryDefinition* library); // Build a FunctionDef for each subgraph, and add it 'library'. The values of // the 'group_attribute' annotations become the function names. @@ -1478,7 +1478,7 @@ Status Encapsulator::CopySubgraphEdges( return Status::OK(); } -Status Encapsulator::SplitIntoSubgraphs() { +Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { Status s; // Map from input graph nodes to subgraph nodes. @@ -1513,6 +1513,15 @@ Status Encapsulator::SplitIntoSubgraphs() { TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy)); } + if (VLOG_IS_ON(1)) { + // Dump subgraphs. + for (auto& entry : subgraphs_) { + dump_graph::DumpGraphToFile( + strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first), + *entry.second.GetGraph(), library); + } + } + return s; } @@ -1936,6 +1945,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // continue. TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); + VLOG(2) << "Node " << src_node->name() + << " has known shape: " << proto.DebugString(); if (dummy_node_images.find(src_node) == dummy_node_images.end()) { dummy_node_images[src_node] = AddDummyShapedNode(src_node, src_port, control_flow_info, @@ -1953,6 +1964,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( if (VLOG_IS_ON(2)) { TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); + VLOG(2) << "Node " << src_node->name() + << " has unknown shape: " << proto.DebugString(); } stack.push_back({src_node, false}); } @@ -2195,6 +2208,23 @@ Status Encapsulator::FindClusterDependencies() { } } } + if (VLOG_IS_ON(2)) { + // Print debug information. + VLOG(2) << "node_ancestors_map:"; + for (const auto& node_iter : node_ancestors_map) { + VLOG(2) << "\t" << node_iter.first->name() << ": subgraph = '" + << node_iter.second.subgraph + << "', outside_compilation_cluster = '" + << node_iter.second.outside_compilation_cluster + << "', ancestor_clusters: " + << (node_iter.second.ancestor_clusters.empty() ? "(empty)" : ""); + for (const auto& cluster_iter : node_iter.second.ancestor_clusters) { + VLOG(2) << "\t\tsubgraph = '" << cluster_iter.subgraph + << "', outside_compilation_cluster = '" + << cluster_iter.outside_compilation_cluster << "'"; + } + } + } return Status::OK(); } @@ -2402,7 +2432,7 @@ Status EncapsulateSubgraphsInFunctions( std::move(outside_compilation_attribute), &graph_in); TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies()); - TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); + TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library)); TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( rewrite_subgraph_fn, reuse_existing_functions, library)); @@ -2451,7 +2481,7 @@ Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, + dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph, options.flib_def); } @@ -2534,7 +2564,7 @@ Status EncapsulateSubgraphsPass::Run( "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, + dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out, options.flib_def); } diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 338fb5a6f0..c5d0e4f8fb 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -51,7 +51,11 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, if (device_type_ == DeviceType(DEVICE_CPU)) { platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = se::cuda::kCudaPlatformId; + platform_id_ = ctx->device() + ->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); } else { platform_id_ = nullptr; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8c3882116d..6558f14dd6 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> +#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" @@ -462,17 +464,27 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld]( - const Node* node, const DeviceType& device_type) { + std::unique_ptr<DeadnessAnalysis> deadness; + { + XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 0); + TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); + } + + auto is_compilable = [&](const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { return false; } + // TODO(b/111570009): This bailout for ControlTrigger is probably not + // needed. + // // Don't compile control trigger nodes. We won't preserve their deadness // semantics correctly, so it's safest not to compile them. - if (node->IsControlTrigger()) return false; + if (node->IsControlTrigger()) { + return false; + } // If this device requires a JIT, we must say yes. if (registration->requires_compilation) return true; @@ -485,6 +497,14 @@ Status MarkForCompilationPass::Run( status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; + // If inputs to `node` can have conflicting deadness (i.e. some are alive + // and some are dead) then don't compile it. XLA cannot represent the + // deadness semantics of these nodes correctly and auto-clustering these + // nodes can cause deadness propagate to nodes that should be live. + if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + return false; + } + // Check for fusable ops only if requested. if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { return false; diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 134dcc1bb5..6adda327f1 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -77,9 +77,7 @@ class XlaAssignVariableOp : public AsyncOpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \ - IdentityNOp); \ + REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 74257b09a8..b70e1cf52b 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> +#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/union_find.h" @@ -146,6 +147,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, TF_RETURN_IF_ERROR( ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + std::unique_ptr<DeadnessAnalysis> deadness; + TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness)); + // Collect nodes that can be fused via XLA, while ignoring those that // explicitly ask for XLA: (*) nodes that are marked to be compiled // explicitly. (*) nodes assigned to XLA device. @@ -185,6 +189,14 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, continue; } + // If inputs to `node` can have conflicting deadness (i.e. some are alive + // and some are dead) then don't compile it. XLA cannot represent the + // deadness semantics of these nodes correctly and auto-clustering these + // nodes can cause deadness propagate to nodes that should be live. + if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + continue; + } + compilation_candidates.insert(node); } diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 9e2ef964a1..7ff01be3cb 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -88,6 +88,38 @@ class XlaSortOpTest(xla_test.XLATestCase): topk, [x.astype(dtype)], expected=[x[indices].astype(dtype), indices]) + def testTopK2D(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 10 + k_options = [0, 1, 2, 10] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + batch = 16 + for x in [np.arange(batch * array_size)]: + np.random.shuffle(x) + x = np.reshape(x, [batch, array_size]) + for k in k_options: + indices = x.argsort(axis=1)[::, -1:-k - 1:-1] + expected = np.sort(x, axis=1)[::, -1:-k - 1:-1] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(dtype)], + expected=[expected.astype(dtype), indices]) + def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index f5fcf3cacd..e2160feba0 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -246,6 +246,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building If"; } +REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 1ddcb08c8e..82d4a69777 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel { OP_REQUIRES(context, input_shape.dims() >= 1, errors::InvalidArgument("input must be >= 1-D, got shape ", input_shape.DebugString())); + int last_dim = input_shape.dims() - 1; + int last_dim_size = input_shape.dim_size(last_dim); OP_REQUIRES( - context, input_shape.dim_size(input_shape.dims() - 1) >= k, + context, last_dim_size >= k, errors::InvalidArgument("input must have at least k columns. Had ", - input_shape.dim_size(input_shape.dims() - 1), - ", needed ", k)); - - OP_REQUIRES( - context, input_shape.dims() == 1, - errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ", - input_shape.DebugString())); + last_dim_size, ", needed ", k)); xla::XlaBuilder* const b = context->builder(); - if (input_shape.dim_size(0) < k) { - k = input_shape.dim_size(0); + if (last_dim_size < k) { + k = last_dim_size; } const xla::XlaOp input = context->Input(0); - xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0)); - xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32); + + xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size); + auto input_dims = input_shape.dim_sizes(); + std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1); + xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims); + xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32); + + std::vector<int64> start_indices(input_shape.dims(), 0); + std::vector<int64> limit_indices(input_dims.begin(), input_dims.end()); + limit_indices[last_dim] = k; + std::vector<int64> strides(input_shape.dims(), 1); + xla::XlaOp values = - xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1})); + xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides)); xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1}); + start_indices, limit_indices, strides); context->SetOutput(0, values); context->SetOutput(1, indices); } diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 9413a30a6c..009fdd81b2 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -299,6 +299,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building while loop"; } +REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 319cbc74e9..cb47581e36 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -422,16 +422,18 @@ Status BuildComputation( // assignment will be placed on this value, which will cause the resource // update to be returned from the same device that provided the resource. handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); - elems.push_back(handle); } } *num_computation_outputs = elems.size(); - // Builds the XLA computation. - if (always_return_tuple || elems.size() != 1) { - xla::Tuple(builder, elems); + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); } builder->ClearOpMetadata(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 6f76816a86..2fb93be01d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -228,6 +228,58 @@ TEST_F(XlaCompilerTest, Simple) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests compilation of a graph where the _Retval node is not necessarily last +// amongst the graph nodes in construction order, and always_return_tuple is +// false. Regression test for bug where the wrong value was returned. +TEST_F(XlaCompilerTest, OutOfOrderGraph) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + // The _Retval node is not last in construction order. + auto d = ops::_Retval(scope.WithOpName("D"), a, 0); + auto c = ops::Add(scope.WithOpName("C"), a, b); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + // Tests that the generated computation works. + std::unique_ptr<xla::Literal> param0_literal = + xla::LiteralUtil::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> param1_literal = + xla::LiteralUtil::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::GlobalData> param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr<xla::GlobalData> param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr<xla::GlobalData> actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr<xla::Literal> actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); +} + TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { // Builds a graph that adds reshapes a tensor, but with the shape not // statically known. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index be55d50b23..66b1c08a39 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -614,6 +614,9 @@ _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) _FORWARD_BINOP(Xor) +_FORWARD_BINOP(ShiftLeft) +_FORWARD_BINOP(ShiftRightArithmetic) +_FORWARD_BINOP(ShiftRightLogical) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 690ff277e8..17ad044578 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -333,6 +333,9 @@ class LocalComputationBuilder { _FORWARD_BINOP(And) _FORWARD_BINOP(Or) _FORWARD_BINOP(Xor) + _FORWARD_BINOP(ShiftLeft) + _FORWARD_BINOP(ShiftRightArithmetic) + _FORWARD_BINOP(ShiftRightLogical) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index afdea88cb7..42bf76e5d8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -989,6 +989,9 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::And; %unignore xla::swig::LocalComputationBuilder::Or; %unignore xla::swig::LocalComputationBuilder::Xor; +%unignore xla::swig::LocalComputationBuilder::ShiftLeft; +%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; +%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index e2b6eaa096..f93d7bda2d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -125,6 +125,9 @@ _BINARY_OPS = [ 'Or', 'Xor', 'Pow', + 'ShiftLeft', + 'ShiftRightArithmetic', + 'ShiftRightLogical', ] diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 0564ddcb85..93177aa647 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -171,6 +171,24 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testShiftLeft(self): + c = self._NewComputation() + c.ShiftLeft(c.Constant(NumpyArrayS32([3])), + c.Constant(NumpyArrayS32([2]))) + self._ExecuteAndCompareClose(c, expected=[12]) + + def testShiftRightArithmetic(self): + c = self._NewComputation() + c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])), + c.Constant(NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[-1]) + + def testShiftRightLogical(self): + c = self._NewComputation() + c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])), + c.Constant(NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) + def testGetProto(self): c = self._NewComputation() c.Add( diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index c74dd648ad..4aacc87b78 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -44,6 +44,10 @@ struct ShapeTreeNode { // Data corresponding to this node. std::pair<ShapeIndex, T> data; + // Children of this node, as indices into the container's nodes_ array. + std::vector<size_t> children; + + // Tells whether this is a leaf node. bool is_leaf = true; explicit ShapeTreeNode(ShapeIndex index) @@ -52,20 +56,6 @@ struct ShapeTreeNode { : data(std::move(index), std::move(data)) {} }; -// Internal representation of an index table entry. -struct IndexTableEntry { - // Index of the node in the ShapeTreeNode vector. - uint32 index; - // Index of the first child in a IndexTableEntry vector. In the index - // table all children entries for a given node will be placed next to each - // other. This allows us to use a single field to index them. - uint32 children_start; -#ifndef NDEBUG - // Number of children, used for bounds checking. - uint32 children_count; -#endif -}; - } // namespace internal template <typename ContainerType, typename IteratorType, typename ValueType> @@ -94,7 +84,6 @@ template <typename T> class ShapeTree { public: using Node = internal::ShapeTreeNode<T>; - using Index = internal::IndexTableEntry; // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -278,12 +267,11 @@ class ShapeTree { private: // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. - void InitChildren(const Shape& shape, const T& init_value, Node* node, - Index* index); + void InitChildren(const Shape& shape, const T& init_value, Node* node); // Initialize node->children based on 'shape'. All children have // default-constructed data values. - void InitChildren(const Shape& shape, Node* node, Index* index); + void InitChildren(const Shape& shape, Node* node); // Returns the number of subshapes, including interior nodes, in shape. int64 CountSubshapes(const Shape& shape); @@ -303,9 +291,6 @@ class ShapeTree { // The nodes in this shape tree. std::vector<Node> nodes_; - // Index table for node lookups. - std::vector<Index> index_table_; - // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. std::shared_ptr<Shape> shape_storage_; @@ -388,74 +373,36 @@ int64 ShapeTree<T>::CountSubshapes(const Shape& shape) { template <typename T> void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value, - Node* node, Index* index) { + Node* node) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); -#ifndef NDEBUG - index->children_count = size; -#endif + node->children.reserve(size); node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); - - // At the end of the index_table, reserve a continuous space to hold the - // children of current node. In order to enforce the invariant that all - // children of a given node are placed together, we need to do the - // reservation before we recurse into any of its children. - int64 children_start_position = index_table_.size(); - index_table_.resize(index_table_.size() + size); - for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; - index_table_[children_start_position + i].index = nodes_.size(); - // The first child of the node in the index table is placed at the end of - // the table. - index_table_[children_start_position + i].children_start = - index_table_.size(); + node->children.push_back(nodes_.size()); nodes_.emplace_back(shape_index, init_value); - InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(), - &index_table_[children_start_position + i]); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); } - } else { -#ifndef NDEBUG - index->children_count = 0; -#endif } } template <typename T> -void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) { +void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); -#ifndef NDEBUG - index->children_count = size; -#endif + node->children.reserve(size); node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); - - // At the end of the index_table, reserve a continuous space to hold the - // children of current node. In order to enforce the invariant that all - // children of a given node are placed together, we need to do the - // reservation before we recurse into any of its children. - int64 children_start_position = index_table_.size(); - index_table_.resize(index_table_.size() + size); - for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; - index_table_[children_start_position + i].index = nodes_.size(); - // The first child of the node in the index table is placed at the end of - // the table. - index_table_[children_start_position + i].children_start = - index_table_.size(); + node->children.push_back(nodes_.size()); nodes_.emplace_back(shape_index); - InitChildren(shape.tuple_shapes(i), &nodes_.back(), - &index_table_[children_start_position + i]); + InitChildren(shape.tuple_shapes(i), &nodes_.back()); } - } else { -#ifndef NDEBUG - index->children_count = 0; -#endif } } @@ -466,36 +413,24 @@ ShapeTree<T>::ShapeTree(Shape shape) // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - const int64 count = CountSubshapes(*shape_); - nodes_.reserve(count); + nodes_.reserve(CountSubshapes(*shape_)); nodes_.emplace_back(ShapeIndex{}); - - index_table_.reserve(count); - index_table_.emplace_back(Index{0, 1}); - InitChildren(*shape_, &nodes_[0], &index_table_[0]); + InitChildren(*shape_, &nodes_[0]); } template <typename T> ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) { - const int64 count = CountSubshapes(*shape_); - nodes_.reserve(count); + nodes_.reserve(CountSubshapes(*shape_)); nodes_.emplace_back(ShapeIndex{}); - - index_table_.reserve(count); - index_table_.emplace_back(Index{0, 1}); - InitChildren(*shape_, &nodes_[0], &index_table_[0]); + InitChildren(*shape_, &nodes_[0]); } template <typename T> ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape) : shape_storage_(shape), shape_(shape_storage_.get()) { - const int64 count = CountSubshapes(*shape_); - nodes_.reserve(count); + nodes_.reserve(CountSubshapes(*shape_)); nodes_.emplace_back(ShapeIndex{}); - - index_table_.reserve(count); - index_table_.emplace_back(Index{0, 1}); - InitChildren(*shape_, &nodes_[0], &index_table_[0]); + InitChildren(*shape_, &nodes_[0]); } template <typename T> @@ -505,38 +440,26 @@ ShapeTree<T>::ShapeTree(Shape shape, const T& init_value) // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - const int64 count = CountSubshapes(*shape_); - nodes_.reserve(count); + nodes_.reserve(CountSubshapes(*shape_)); nodes_.emplace_back(ShapeIndex{}, init_value); - - index_table_.reserve(count); - index_table_.emplace_back(Index{0, 1}); - InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); + InitChildren(*shape_, init_value, &nodes_[0]); } template <typename T> ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value) : shape_(shape) { - const int64 count = CountSubshapes(*shape_); - nodes_.reserve(count); + nodes_.reserve(CountSubshapes(*shape_)); nodes_.emplace_back(ShapeIndex{}, init_value); - - index_table_.reserve(count); - index_table_.emplace_back(Index{0, 1}); - InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); + InitChildren(*shape_, init_value, &nodes_[0]); } template <typename T> ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value) : shape_storage_(shape), shape_(shape_storage_.get()) { - const int64 count = CountSubshapes(*shape_); - nodes_.reserve(count); + nodes_.reserve(CountSubshapes(*shape_)); nodes_.emplace_back(ShapeIndex{}, init_value); - - index_table_.reserve(count); - index_table_.emplace_back(Index{0, 1}); - InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); + InitChildren(*shape_, init_value, &nodes_[0]); } template <typename T> @@ -551,16 +474,13 @@ T* ShapeTree<T>::mutable_element(ShapeIndexView index) { template <typename T> internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) { - Index* iter = &index_table_[0]; + Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); -#ifndef NDEBUG - CHECK_LT(i, iter->children_count); -#endif - iter = &index_table_[iter->children_start + i]; + CHECK_LT(i, node->children.size()); + node = &nodes_[node->children[i]]; } - - return &nodes_[iter->index]; + return node; } template <typename T> diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 4391078b64..51de82e957 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -227,16 +227,14 @@ TEST_F(ShapeTreeTest, NestedTupleShape) { TEST_F(ShapeTreeTest, InvalidIndexingTuple) { ShapeTree<int> shape_tree{tuple_shape_}; -#ifndef NDEBUG + EXPECT_DEATH(shape_tree.element({4}), ""); -#endif } TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { ShapeTree<int> shape_tree{nested_tuple_shape_}; -#ifndef NDEBUG + EXPECT_DEATH(shape_tree.element({0, 0}), ""); -#endif } TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { @@ -604,15 +602,12 @@ void BM_Iterate(int iters, int depth, int fan_out) { } } -#define BENCHMARK_WITH_ARGS(name) \ - BENCHMARK(name)->ArgPair(2, 8)->ArgPair(1, 1000) - -BENCHMARK_WITH_ARGS(BM_Construct); -BENCHMARK_WITH_ARGS(BM_ConstructUnowned); -BENCHMARK_WITH_ARGS(BM_Copy); -BENCHMARK_WITH_ARGS(BM_Move); -BENCHMARK_WITH_ARGS(BM_ForEach); -BENCHMARK_WITH_ARGS(BM_Iterate); +BENCHMARK(BM_Construct)->ArgPair(2, 8); +BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8); +BENCHMARK(BM_Copy)->ArgPair(2, 8); +BENCHMARK(BM_Move)->ArgPair(2, 8); +BENCHMARK(BM_ForEach)->ArgPair(2, 8); +BENCHMARK(BM_Iterate)->ArgPair(2, 8); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 83d15e8fe3..17c1d7b10a 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" @@ -74,12 +73,10 @@ class ShapeIndex { // push_front is O(n^2), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } - using container_type = gtl::InlinedVector<int64, 2>; - - container_type::const_iterator begin() const { return indices_.begin(); } - container_type::const_iterator end() const { return indices_.end(); } - container_type::iterator begin() { return indices_.begin(); } - container_type::iterator end() { return indices_.end(); } + std::vector<int64>::const_iterator begin() const { return indices_.begin(); } + std::vector<int64>::const_iterator end() const { return indices_.end(); } + std::vector<int64>::iterator begin() { return indices_.begin(); } + std::vector<int64>::iterator end() { return indices_.end(); } const int64* data() const { return indices_.data(); } @@ -100,7 +97,7 @@ class ShapeIndex { string ToString() const; private: - container_type indices_; + std::vector<int64> indices_; }; // A view into a ShapeIndex as above, with the cheap/easy ability to consume the diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index a25232f713..5a5a6ad63a 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -171,8 +171,8 @@ class ControlFlowTransformer(converter.Base): # actually has some return value as well. cond_results = None # TODO(mdan): This doesn't belong here; it's specific to the operator. - returned_from_body = templates.replace_as_expression('1') - returned_from_orelse = templates.replace_as_expression('1') + returned_from_body = templates.replace_as_expression('tf.constant(1)') + returned_from_orelse = templates.replace_as_expression('tf.constant(1)') body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 6670b8a66f..ade3501426 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -31,7 +31,8 @@ class ControlFlowTest(converter_testing.TestCase): def assertTransformedResult(self, test_fn, inputs, expected): if not isinstance(inputs, tuple): inputs = (inputs,) - with self.converted(test_fn, control_flow, {}) as result: + with self.converted(test_fn, control_flow, {}, + constant_op.constant) as result: with self.test_session() as sess: self.assertEqual(sess.run(result.test_fn(*inputs)), expected) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 4d3d531299..242c1e8ba4 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -35,9 +35,9 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() slotdeps = self.slotdeps slots = [] - slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" - slots.append(slotdeps.track(tfe.Variable(4.), "y")) - slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + slots.append(slotdeps.track(tf.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tf.Variable(4.), "y")) + slots.append(slotdeps.track(tf.Variable(5.), "x")) # Named "x_1" ``` """ diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 844f62649d..7b892ba248 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,6 +68,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 7b69e10441..566cbb246a 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -71,8 +71,19 @@ cc_library( ) cc_library( + name = "assert_next_dataset_op", + srcs = ["assert_next_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + +cc_library( name = "dataset_kernels", deps = [ + ":assert_next_dataset_op", ":csv_dataset_op", ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc new file mode 100644 index 0000000000..95b8e1f7fd --- /dev/null +++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <map> + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. +class AssertNextDatasetOp : public UnaryDatasetOpKernel { + public: + explicit AssertNextDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + std::vector<string> transformations; + OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations", + &transformations)); + *output = + new Dataset(ctx, input, transformations, output_types_, output_shapes_); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const std::vector<string>& transformations, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : GraphDatasetBase(ctx), + input_(input), + transformations_(transformations), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Assert")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "AssertNextDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + Node* transformations_node = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, transformations_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + std::vector<string> tokens = + str_util::Split(prefix(), ':', str_util::SkipEmpty()); + if (dataset()->transformations_.size() > tokens.size() - 2) { + return errors::InvalidArgument( + "Asserted next ", dataset()->transformations_.size(), + " transformations but encountered only ", tokens.size() - 2, "."); + } + int n = tokens.size(); + for (size_t i = 0; i < dataset()->transformations_.size(); ++i) { + if (dataset()->transformations_[i] != tokens[n - 2 - i]) { + return errors::InvalidArgument( + "Asserted ", dataset()->transformations_[i], + " transformation at offset ", i, " but encountered ", + tokens[n - 2 - i], " transformation instead."); + } + } + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + return Status::OK(); + } + + private: + std::unique_ptr<IteratorBase> input_impl_; + }; + + const DatasetBase* input_; + const std::vector<string> transformations_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU), + AssertNextDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index dadde705e1..f7e3ed886c 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -150,6 +150,7 @@ class CSVDatasetOp : public DatasetOpKernel { delim_(delim), na_value_(std::move(na_value)), use_compression_(!compression_type.empty()), + compression_type_(std::move(compression_type)), options_(options) {} std::unique_ptr<IteratorBase> MakeIteratorInternal( @@ -169,10 +170,45 @@ class CSVDatasetOp : public DatasetOpKernel { protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, Node** output) const override { - // TODO(rachelim): Implement this - std::vector<Node*> input_tensors; - TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); - return errors::Unimplemented("CSVDataset: AsGraphDefInternal"); + Node* filenames = nullptr; + Node* compression_type = nullptr; + Node* buffer_size = nullptr; + Node* header = nullptr; + Node* delim = nullptr; + Node* use_quote_delim = nullptr; + Node* na_value = nullptr; + Node* select_cols = nullptr; + + std::vector<Node*> record_defaults; + record_defaults.reserve(record_defaults_.size()); + for (const Tensor& t : record_defaults_) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + record_defaults.emplace_back(node); + } + + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); + TF_RETURN_IF_ERROR( + b->AddScalar(options_.input_buffer_size, &buffer_size)); + TF_RETURN_IF_ERROR(b->AddScalar(header_, &header)); + + string delim_string(1, delim_); + TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim)); + TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim)); + TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value)); + TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols)); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {std::make_pair(0, filenames), std::make_pair(1, compression_type), + std::make_pair(2, buffer_size), std::make_pair(3, header), + std::make_pair(4, delim), std::make_pair(5, use_quote_delim), + std::make_pair(6, na_value), + std::make_pair(7, select_cols)}, // Single tensor inputs + {std::make_pair(8, record_defaults)}, // Tensor list inputs + {}, output)); + return Status::OK(); } private: @@ -224,14 +260,58 @@ class CSVDatasetOp : public DatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - // TODO(rachelim): Implement save - return errors::Unimplemented("CSVDataset: SaveInternal"); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), + current_file_index_)); + // `input_stream_` is empty if + // 1. GetNext has not been called even once. + // 2. All files have been read and the iterator has been exhausted. + if (input_stream_ && num_buffer_reads_ > 0) { + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_)); + // If num_buffer_reads_ == 0, the buffer hasn't been filled even once. + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"), + num_buffer_reads_)); + } + return Status::OK(); } + Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - // TODO(rachelim): Implement restore - return errors::Unimplemented("CSVDataset: RestoreInternal"); + ResetStreamsLocked(); + int64 current_file_index; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), + ¤t_file_index)); + current_file_index_ = size_t(current_file_index); + // The keys "pos" and "num_buffer_reads" are written only if + // the iterator was saved with an open, partially read file. + if (reader->Contains(full_name("pos"))) { + int64 pos, num_buffer_reads; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"), + &num_buffer_reads)); + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + + num_buffer_reads_ = size_t(num_buffer_reads - 1); + + // Restores the most recently held buffer + Status s = input_stream_->SkipNBytes( + num_buffer_reads_ * dataset()->options_.input_buffer_size); + if (!s.ok() && !errors::IsOutOfRange(s)) { + // We might get out of range error here if the size of the file + // is not an exact multiple of the buffer size, and the last buffer + // read is < buffer_size. This is valid and we do not surface the + // error. + return s; + } + + Status s2 = FillBuffer(&buffer_); + if (!s2.ok() && !errors::IsOutOfRange(s2)) { + return s2; + } + pos_ = size_t(pos); + } + return Status::OK(); } private: @@ -533,6 +613,7 @@ class CSVDatasetOp : public DatasetOpKernel { Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { result->clear(); + ++num_buffer_reads_; Status s = input_stream_->ReadNBytes( dataset()->options_.input_buffer_size, result); @@ -712,6 +793,7 @@ class CSVDatasetOp : public DatasetOpKernel { } buffer_.clear(); pos_ = 0; + num_buffer_reads_ = 0; if (dataset()->header_) { // Read one line, but don't include it. Pass nullptrs as dummy // pointers to objects that shouldn't be invoked anyway @@ -737,6 +819,7 @@ class CSVDatasetOp : public DatasetOpKernel { string buffer_ GUARDED_BY(mu_); // Maintain our own buffer size_t pos_ GUARDED_BY( mu_); // Index into the buffer must be maintained between iters + size_t num_buffer_reads_ GUARDED_BY(mu_); std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_ GUARDED_BY(mu_); std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_); @@ -755,6 +838,7 @@ class CSVDatasetOp : public DatasetOpKernel { const char delim_; const string na_value_; const bool use_compression_; + const string compression_type_; const io::ZlibCompressionOptions options_; }; // class Dataset diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index a623c27ff8..b5c6f2e241 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -177,4 +177,17 @@ display_name: A human-readable name for the threads that may be visible in some visualizations. )doc"); +REGISTER_OP("AssertNextDataset") + .Input("input_dataset: variant") + .Input("transformations: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // transformations should be a vector. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + return shape_inference::ScalarShape(c); + }); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 18457320b9..d372bed479 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -208,7 +208,6 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index 21eebccd11..cfef40e192 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import optimization -from tensorflow.core.framework import graph_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -26,41 +25,76 @@ from tensorflow.python.platform import test class OptimizeDatasetTest(test.TestCase): + def testAssertSuffix(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + + def testAssertSuffixInvalid(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted Whoops transformation at offset 0 but encountered " + "Map transformation instead." + ): + sess.run(get_next) + + def testAssertSuffixShort(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted next 2 transformations but encountered only 1."): + sess.run(get_next) + def testDefaultOptimizations(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize()) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize()) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testEmptyOptimizations(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize([])) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize([])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testOptimization(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize(["map_and_batch_fusion"])) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 686788522a..3c3f23f9a9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -73,6 +73,20 @@ py_test( ) py_test( + name = "csv_dataset_serialization_test", + size = "small", + srcs = ["csv_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + ], +) + +py_test( name = "dataset_constructor_serialization_test", size = "medium", srcs = ["dataset_constructor_serialization_test.py"], diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py new file mode 100644 index 0000000000..247f2046ea --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py @@ -0,0 +1,73 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the CsvDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.platform import test + + +class CsvDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._num_cols = 7 + self._num_rows = 10 + self._num_epochs = 14 + self._num_outputs = self._num_rows * self._num_epochs + + inputs = [ + ",".join(str(self._num_cols * j + i) + for i in range(self._num_cols)) + for j in range(self._num_rows) + ] + contents = "\n".join(inputs).encode("utf-8") + + self._filename = os.path.join(self.get_temp_dir(), "file.csv") + self._compressed = os.path.join(self.get_temp_dir(), + "comp.csv") # GZip compressed + + with open(self._filename, "wb") as f: + f.write(contents) + with gzip.GzipFile(self._compressed, "wb") as f: + f.write(contents) + + def ds_func(self, **kwargs): + compression_type = kwargs.get("compression_type", None) + if compression_type == "GZIP": + filename = self._compressed + elif compression_type is None: + filename = self._filename + else: + raise ValueError("Invalid compression type:", compression_type) + + return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs) + + def testSerializationCore(self): + defs = [[0]] * self._num_cols + self.run_core_tests( + lambda: self.ds_func(record_defaults=defs, buffer_size=2), + lambda: self.ds_func(record_defaults=defs, buffer_size=12), + self._num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index cf89657226..018c5115e1 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -18,12 +18,34 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to +# account for indexing) and transformation sequence. +def assert_next(transformations): + """A transformation that asserts which transformations happen next. + + Args: + transformations: A `tf.string` vector `tf.Tensor` identifying the + transformations that are expected to happen next. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _AssertNextDataset(dataset, transformations) + + return _apply_fn + + def optimize(optimizations=None): """A transformation that applies optimizations. @@ -44,6 +66,37 @@ def optimize(optimizations=None): return _apply_fn +class _AssertNextDataset(dataset_ops.Dataset): + """A `Dataset` that asserts which transformations happen next.""" + + def __init__(self, input_dataset, transformations): + """See `assert_next()` for details.""" + super(_AssertNextDataset, self).__init__() + self._input_dataset = input_dataset + if transformations is None: + raise ValueError("At least one transformation should be specified") + self._transformations = ops.convert_to_tensor( + transformations, dtype=dtypes.string, name="transformations") + + def _as_variant_tensor(self): + return contrib_gen_dataset_ops.assert_next_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._transformations, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + class _OptimizeDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and applies optimizations.""" diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb new file mode 100644 index 0000000000..43c8c355dc --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -0,0 +1,711 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# DCGAN: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). To do this, we use Deep Convolutional Generative Adverserial Networks ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)).\n", + "\n", + "On a colab GPU(Tesla K80), the model takes around 40 seconds per epoch to train.\n", + "\n", + "Below is the output generated after training the generator and discriminator models for 100 epochs.\n", + "\n", + "![sample output](https://tensorflow.org/images/gan/dcgan.gif)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "u_2z-B3piVsw" + }, + "outputs": [], + "source": [ + "# to generate gifs\n", + "!pip install imageio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import glob\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "import imageio\n", + "from IPython import display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the dataset\n", + "\n", + "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will then generate handwritten digits." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "a4fYMGxGhrna" + }, + "outputs": [], + "source": [ + "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NFC2ghIdiZYE" + }, + "outputs": [], + "source": [ + "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", + "# We are normalizing the images to the range of [-1, 1]\n", + "train_images = (train_images - 127.5) / 127.5" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "S4PIDhoDLbsZ" + }, + "outputs": [], + "source": [ + "BUFFER_SIZE = 60000\n", + "BATCH_SIZE = 256" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use tf.data to create batches and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-yKCCQOoJ7cn" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Write the generator and discriminator models\n", + "\n", + "* **Generator** \n", + " * It is responsible for **creating the convincing images good enough to fool the discriminator**.\n", + " * It consists of Conv2DTranspose(Upsampling) layers. We start with a fully connected layer and upsample the image 2 times so as to reach the desired image size(mnist image size) which is (28, 28, 1). \n", + " * We use **leaky relu** activation except for the **last layer** which uses **tanh** activation.\n", + " \n", + "* **Discriminator**\n", + " * **The discriminator is responsible for classifying the fake images from the real images.**\n", + " * In other words, the discriminator is given generated images(from the generator) and the real MNIST images. The job of the discriminator is to classify these images into fake(generated) and real(MNIST images).\n", + " * **Basically the generator should be good enough to fool the discriminator that the generated images are real**." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VGLbvBEmjK0a" + }, + "outputs": [], + "source": [ + "class Generator(tf.keras.Model):\n", + " def __init__(self):\n", + " super(Generator, self).__init__()\n", + " self.fc1 = tf.keras.layers.Dense(7*7*64, use_bias=False)\n", + " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)\n", + " self.batchnorm2 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", + " self.batchnorm3 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", + "\n", + " def call(self, x, training=True):\n", + " x = self.fc1(x)\n", + " x = self.batchnorm1(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = tf.reshape(x, shape=(-1, 7, 7, 64))\n", + "\n", + " x = self.conv1(x)\n", + " x = self.batchnorm2(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2(x)\n", + " x = self.batchnorm3(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = tf.nn.tanh(self.conv3(x)) \n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "bkOfJxk5j5Hi" + }, + "outputs": [], + "source": [ + "class Discriminator(tf.keras.Model):\n", + " def __init__(self):\n", + " super(Discriminator, self).__init__()\n", + " self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')\n", + " self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')\n", + " self.dropout = tf.keras.layers.Dropout(0.3)\n", + " self.flatten = tf.keras.layers.Flatten()\n", + " self.fc1 = tf.keras.layers.Dense(1)\n", + "\n", + " def call(self, x, training=True):\n", + " x = tf.nn.leaky_relu(self.conv1(x))\n", + " x = self.dropout(x, training=training)\n", + " x = tf.nn.leaky_relu(self.conv2(x))\n", + " x = self.dropout(x, training=training)\n", + " x = self.flatten(x)\n", + " x = self.fc1(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gDkA05NE6QMs" + }, + "outputs": [], + "source": [ + "generator = Generator()\n", + "discriminator = Discriminator()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss functions and the optimizer\n", + "\n", + "* **Discriminator loss**\n", + " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", + " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n", + " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n", + " * Then the total_loss is the sum of real_loss and the generated_loss\n", + " \n", + "* **Generator loss**\n", + " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**\n", + " \n", + "\n", + "* The discriminator and the generator optimizers are different since we will train them separately." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "wkMNfBWlT-PV" + }, + "outputs": [], + "source": [ + "def discriminator_loss(real_output, generated_output):\n", + " # [1,1,...,1] with real output since it is true and we want\n", + " # our generated examples to look like it\n", + " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", + "\n", + " # [0,0,...,0] with generated images since they are fake\n", + " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)\n", + "\n", + " total_loss = real_loss + generated_loss\n", + "\n", + " return total_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "90BIcCKcDMxz" + }, + "outputs": [], + "source": [ + "def generator_loss(generated_output):\n", + " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)\n", + "generator_optimizer = tf.train.AdamOptimizer(1e-4)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* The generator is given **noise as an input** which when passed through the generator model will output a image looking like a handwritten digit\n", + "* The discriminator is given the **real MNIST images as well as the generated images(from the generator)**.\n", + "* Next, we calculate the generator and the discriminator loss.\n", + "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, its time to generate some images!\n", + "* We start by creating noise array as an input to the generator\n", + "* The generator will then convert the noise into handwritten images.\n", + "* Last step is to plot the predictions and **voila!**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "EPOCHS = 150\n", + "noise_dim = 100\n", + "num_examples_to_generate = 100\n", + "\n", + "# keeping the random vector constant for generation(prediction) so\n", + "# it will be easier to see the improvement of the gan.\n", + "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", + " noise_dim])" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_and_save_images(model, epoch, test_input):\n", + " # make sure the training parameter is set to False because we\n", + " # don't want to train the batchnorm layer when doing inference.\n", + " predictions = model(test_input, training=False)\n", + "\n", + " fig = plt.figure(figsize=(10,10))\n", + " \n", + " for i in range(predictions.shape[0]):\n", + " plt.subplot(10, 10, i+1)\n", + " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", + " plt.axis('off')\n", + " \n", + " # tight_layout minimizes the overlap between 2 sub-plots\n", + " plt.tight_layout()\n", + " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "def train(dataset, epochs, noise_dim): \n", + " for epoch in range(epochs):\n", + " start = time.time()\n", + " \n", + " for images in dataset:\n", + " # generating noise from a uniform distribution\n", + " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", + " \n", + " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", + " generated_images = generator(noise, training=True)\n", + " \n", + " real_output = discriminator(images, training=True)\n", + " generated_output = discriminator(generated_images, training=True)\n", + " \n", + " gen_loss = generator_loss(generated_output)\n", + " disc_loss = discriminator_loss(real_output, generated_output)\n", + " \n", + " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", + " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", + " \n", + " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", + " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))\n", + "\n", + " \n", + " if epoch % 10 == 0:\n", + " display.clear_output(wait=True)\n", + " generate_and_save_images(generator,\n", + " epoch + 1,\n", + " random_vector_for_generation)\n", + "\n", + " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", + " time.time()-start))\n", + " # generating after the final epoch\n", + " generate_and_save_images(generator,\n", + " epochs,\n", + " random_vector_for_generation)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Ly3UN0SLLY2l" + }, + "outputs": [], + "source": [ + "train(train_dataset, EPOCHS, noise_dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P4M_vIbUi7c0" + }, + "source": [ + "# Display an image using the epoch number" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WfO5wCdclHGL" + }, + "outputs": [], + "source": [ + "def display_image(epoch_no):\n", + " plt.figure(figsize=(15,15))\n", + " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n", + " plt.axis('off')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "5x3q9_Oe5q0A" + }, + "outputs": [], + "source": [ + "display_image(EPOCHS)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NywiH3nL8guF" + }, + "source": [ + "## Generate a GIF of all the saved images." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xmO0Dmu2WICn" + }, + "source": [ + "\u003c!-- TODO(markdaoust): Remove the hack when Ipython version is updated --\u003e\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IGKQgENQ8lEI" + }, + "outputs": [], + "source": [ + "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", + " filenames = glob.glob('image*.png')\n", + " filenames = sorted(filenames)\n", + " for filename in filenames:\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " # this is a hack to display the gif inside the notebook\n", + " os.system('mv dcgan.gif dcgan.gif.png')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "uV0yiKpzNP1b" + }, + "outputs": [], + "source": [ + "display.Image(filename=\"dcgan.gif.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "4UJjSnIMOzOJ" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "dcgan.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py index 729d8525fa..275aee5130 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -54,7 +54,7 @@ class Dynamics(tf.keras.Model): self.position_fn = neural_nets.GenericNet(x_dim, factor=2.) self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.) - self.eps = tfe.Variable( + self.eps = tf.Variable( initial_value=eps, name="eps", dtype=tf.float32, trainable=True) def apply_transition(self, position): diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py index e230ad5e25..68e0bc3123 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -25,7 +25,6 @@ from __future__ import division from __future__ import print_function import tensorflow as tf -import tensorflow.contrib.eager as tfe class GenericNet(tf.keras.Model): @@ -47,13 +46,13 @@ class GenericNet(tf.keras.Model): # Scale self.scale_layer = _custom_dense(x_dim, .001) - self.coeff_scale = tfe.Variable( + self.coeff_scale = tf.Variable( initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True) # Translation self.translation_layer = _custom_dense(x_dim, factor=.001) # Transformation self.transformation_layer = _custom_dense(x_dim, .001) - self.coeff_transformation = tfe.Variable( + self.coeff_transformation = tf.Variable( initial_value=tf.zeros([1, x_dim]), name='coeff_transformation', trainable=True) diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb index 591e2d0c85..5f1b48fa0d 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb @@ -118,7 +118,6 @@ "cell_type": "code", "source": [ "import tensorflow as tf\n", - "tfe = tf.contrib.eager # Shorthand for some symbols\n", "\n", "tf.enable_eager_execution()" ], @@ -184,7 +183,7 @@ }, "cell_type": "code", "source": [ - "v = tfe.Variable(1.0)\n", + "v = tf.Variable(1.0)\n", "assert v.numpy() == 1.0\n", "\n", "# Re-assign the value\n", @@ -258,8 +257,8 @@ " def __init__(self):\n", " # Initialize variable to (5.0, 0.0)\n", " # In practice, these should be initialized to random values.\n", - " self.W = tfe.Variable(5.0)\n", - " self.b = tfe.Variable(0.0)\n", + " self.W = tf.Variable(5.0)\n", + " self.b = tf.Variable(0.0)\n", " \n", " def __call__(self, x):\n", " return self.W * x + self.b\n", diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index b2ac4b67c9..b0d0a5486d 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -138,7 +138,7 @@ class RevNetTest(tf.test.TestCase): minval=0, maxval=self.config.n_classes, dtype=tf.int32) - global_step = tfe.Variable(0., trainable=False) + global_step = tf.Variable(0., trainable=False) model = revnet.RevNet(config=config) model(x) updates = model.get_updates_for(x) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index c2340a293a..d64bf5354e 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -310,7 +310,7 @@ def main(_): with tf.device("/device:GPU:0" if have_gpu else None): # Make learning_rate a Variable so it can be included in the checkpoint # and we can resume training with the last saved learning_rate. - learning_rate = tfe.Variable(20.0, name="learning_rate") + learning_rate = tf.Variable(20.0, name="learning_rate") model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, use_cudnn_rnn) diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py index 561be36c91..8130414985 100644 --- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py +++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py @@ -62,7 +62,7 @@ class SelfAttentionModule(tf.keras.Model): kernel_size=1, strides=(1, 1), data_format=data_format) - self.scale = tfe.Variable(0., trainable=True) + self.scale = tf.Variable(0., trainable=True) def call(self, x): f = self.f(x) diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb index 4f1410e00b..f3a65f5aab 100644 --- a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb +++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb @@ -69,7 +69,7 @@ "cell_type": "code", "source": [ "# Creating variables\n", - "v = tfe.Variable(1.0)\n", + "v = tf.Variable(1.0)\n", "v" ], "execution_count": 2, diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index db50b33af2..4454abfb96 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import numerics -from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer @@ -45,12 +44,6 @@ class TFETest(test_util.TensorFlowTestCase): r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) - def testVariableError(self): - with self.assertRaisesRegexp( - RuntimeError, - r'Variable not supported when eager execution is enabled'): - variables.Variable(initial_value=1.0) - def testGradients(self): def square(x): diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD index 9d8c20e96f..9f31ffdf67 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -39,6 +39,38 @@ cc_test( ) cc_library( + name = "delegate_data", + srcs = ["delegate_data.cc"], + hdrs = ["delegate_data.h"], + tags = [ + "no_oss", + "tflite_not_portable", + ], + deps = [ + ":buffer_map", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime/eager:context", + ], +) + +cc_test( + name = "delegate_data_test", + size = "small", + srcs = ["delegate_data_test.cc"], + tags = [ + "tflite_not_portable", + ], + deps = [ + ":delegate_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( name = "util", srcs = ["util.cc"], hdrs = ["util.h"], diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc index e4a780b735..1d6453f498 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/log_memory.h" namespace tflite { +namespace eager { namespace { // A tensor buffer that is allocated, deallocated and populated by TF Lite. class TfLiteTensorBuffer : public tensorflow::TensorBuffer { @@ -102,4 +103,5 @@ void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) { id_to_tensor_[tensor_index] = std::move(tensor); } +} // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h index 922f67f574..a28329ae7d 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" namespace tflite { +namespace eager { // Maps a TF Lite tensor index into a TensorFlow tensor. // @@ -54,6 +55,7 @@ class BufferMap { std::map<int, tensorflow::Tensor> id_to_tensor_; }; +} // namespace eager } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc index c447eeaa05..dcb3f6c941 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/contrib/lite/util.h" namespace tflite { +namespace eager { namespace { using ::testing::ElementsAre; @@ -163,6 +164,7 @@ TEST(BufferMapTest, TensorFlowOverwritesTfLite) { } } // namespace +} // namespace eager } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc new file mode 100644 index 0000000000..29687694bd --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { +tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) { + std::vector<tensorflow::Device*> devices; + + TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( + tensorflow::SessionOptions(), "/device:cpu:*", &devices)); + + std::unique_ptr<tensorflow::DeviceMgr> device_mgr( + new tensorflow::DeviceMgr(devices)); + // Note that Rendezvous is ref-counted so it will be automatically deleted. + tensorflow::Rendezvous* rendezvous = + new tensorflow::IntraProcessRendezvous(device_mgr.get()); + data->reset(new DelegateData(new tensorflow::EagerContext( + tensorflow::SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + /*async=*/false, std::move(device_mgr), rendezvous))); + return tensorflow::Status(); +} + +DelegateData::DelegateData(tensorflow::EagerContext* eager_context) + : eager_context_(eager_context) {} + +DelegateData::~DelegateData() {} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/eager/delegate_data.h new file mode 100644 index 0000000000..8a0e8ba8bf --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ + +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/core/common_runtime/eager/context.h" + +namespace tflite { +namespace eager { + +// Data kept by the Eager delegate for the lifetime of an Interpreter. +class DelegateData { + public: + // Create a new DelegateData, initialized with a newly-created EagerContext. + static tensorflow::Status Create(std::unique_ptr<DelegateData>* data); + + ~DelegateData(); + + // The EagerContext that is required for execution of Eager Ops. + tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } + + // Map from TF Lite tensor index to TensorFlow tensor. + BufferMap* GetBufferMap() { return &buffer_map_; } + + private: + explicit DelegateData(tensorflow::EagerContext* eager_context); + + std::unique_ptr<tensorflow::EagerContext> eager_context_; + BufferMap buffer_map_; +}; + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc new file mode 100644 index 0000000000..30251b8f82 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace eager { +namespace { + +TEST(DelegateDataTest, Basic) { + std::unique_ptr<DelegateData> data; + // We only check for success because it is hard to make initialization fail. + // It only happens if we manage to not link the CPU device factory into the + // binary. + EXPECT_TRUE(DelegateData::Create(&data).ok()); + + EXPECT_NE(data->GetEagerContext(), nullptr); + EXPECT_NE(data->GetBufferMap(), nullptr); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc index e1879bdaff..4426c653e6 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.cc +++ b/tensorflow/contrib/lite/delegates/eager/util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/delegates/eager/util.h" namespace tflite { +namespace eager { TfLiteStatus ConvertStatus(TfLiteContext* context, const tensorflow::Status& status) { @@ -67,4 +68,5 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { } } +} // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h index 12b33b9b49..a9407be071 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.h +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tflite { +namespace eager { // Converts a tensorflow:Status into a TfLiteStatus. If the original status // represented an error, reports it using the given 'context'. @@ -35,6 +36,7 @@ TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, // Returns the TF C API Data type that corresponds to the given TfLiteType. TF_DataType GetTensorFlowDataType(TfLiteType type); +} // namespace eager } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc index 53ed4db972..c4fbf54127 100644 --- a/tensorflow/contrib/lite/delegates/eager/util_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { +namespace eager { namespace { using tensorflow::DT_FLOAT; @@ -102,6 +103,7 @@ TEST(UtilTest, TypeConversions) { } } // namespace +} // namespace eager } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc index 474d323bc3..74b3aef5bd 100644 --- a/tensorflow/contrib/lite/kernels/pow_test.cc +++ b/tensorflow/contrib/lite/kernels/pow_test.cc @@ -50,22 +50,22 @@ class PowOpModel : public SingleOpModel { }; TEST(PowOpModel, Simple) { - PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {}}); - model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8}); - model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 1}); + PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor<int32_t>(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 1}); model.Invoke(); EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8)); } TEST(PowOpModel, NegativeAndZeroValue) { - PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {}}); - model.PopulateTensor<int32>(model.input1(), {0, 2, -7, 8}); - model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 0}); + PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor<int32_t>(model.input1(), {0, 2, -7, 8}); + model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 0}); model.Invoke(); EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1)); @@ -98,10 +98,10 @@ TEST(PowOpModel, NegativeFloatTest) { } TEST(PowOpModel, BroadcastTest) { - PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {1}}, {TensorType_INT32, {}}); - model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8}); - model.PopulateTensor<int32>(model.input2(), {4}); + PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1}}, {TensorType_INT32, {}}); + model.PopulateTensor<int32_t>(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor<int32_t>(model.input2(), {4}); model.Invoke(); EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096)); diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index c38b692dcd..f97919363b 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -340,6 +340,8 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { namespace { +// Checks to see if a tensor access can succeed (returns nullptr on error). +// Otherwise returns Py_None. PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, TfLiteTensor** tensor, int* type_num) { TFLITE_PY_ENSURE_VALID_INTERPRETER(); @@ -362,7 +364,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, return nullptr; } - return nullptr; + Py_RETURN_NONE; } } // namespace @@ -371,10 +373,12 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); + std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); // Make a buffer copy but we must tell Numpy It owns that data or else @@ -396,10 +400,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 5e197e584c..c88079717d 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -93,6 +93,7 @@ cc_library( ":runtime", ":toco_port", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", ], ) @@ -246,6 +247,7 @@ cc_library( "graph_transformations/resolve_constant_transpose.cc", "graph_transformations/resolve_constant_unary.cc", "graph_transformations/resolve_fake_quant_args_from_vars.cc", + "graph_transformations/resolve_gather_attributes.cc", "graph_transformations/resolve_multiply_by_zero.cc", "graph_transformations/resolve_pad_attributes.cc", "graph_transformations/resolve_padv2_attributes.cc", diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 4508aa6632..f9a6d31d60 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -215,6 +215,30 @@ void ConvertFloatTensorConst(const Model& model, const string& name, LegacyScalarPolicy::kAvoidLegacyScalars); } +void ConvertBoolTensorConst(const Model& model, const string& name, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + CHECK(model.HasArray(name)); + const auto& array = model.GetArray(name); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_BOOL); + const auto& data = array.GetBuffer<ArrayDataType::kBool>().data; + for (auto index : data) { + tensor->add_bool_val(index); + } + const auto& array_shape = array.shape(); + auto* shape = tensor->mutable_tensor_shape(); + for (int i = 0; i < array_shape.dimensions_count(); i++) { + shape->add_dim()->set_size(array_shape.dims(i)); + } +} + void ConvertIntTensorConst(const Model& model, const string& name, GraphDef* tensorflow_graph) { if (HasAlreadyExportedConst(name, *tensorflow_graph)) { @@ -621,7 +645,8 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op, CHECK_EQ(src_op.inputs.size(), 2); *add_op->add_input() = src_op.inputs[0]; *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, @@ -633,7 +658,8 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, *add_op->add_input() = input; } (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size()); - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertMulOperator(const Model& model, const MulOperator& src_op, @@ -644,16 +670,18 @@ void ConvertMulOperator(const Model& model, const MulOperator& src_op, CHECK_EQ(src_op.inputs.size(), 2); *add_op->add_input() = src_op.inputs[0]; *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } -void ConvertReluOperator(const ReluOperator& src_op, +void ConvertReluOperator(const Model& model, const ReluOperator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; - (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*relu_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertRelu1Operator(const Relu1Operator& src_op, @@ -1110,13 +1138,27 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); - gather_op->set_op("Gather"); + gather_op->set_op("GatherV2"); gather_op->set_name(src_op.outputs[0]); - CHECK_EQ(src_op.inputs.size(), 2); *gather_op->add_input() = src_op.inputs[0]; *gather_op->add_input() = src_op.inputs[1]; + if (!src_op.axis) { + // Dynamic axis. + CHECK_EQ(src_op.inputs.size(), 3); + *gather_op->add_input() = src_op.inputs[2]; + } else { + // Constant axis. + CHECK_EQ(src_op.inputs.size(), 2); + const string gather_axis = + AvailableArrayName(model, gather_op->name() + "/axis"); + CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {}, + tensorflow_graph); + *gather_op->add_input() = gather_axis; + } + (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); + (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32); const tensorflow::DataType params_type = GetTensorFlowDataType(model, src_op.inputs[0]); (*gather_op->mutable_attr())["Tparams"].set_type(params_type); @@ -1638,6 +1680,9 @@ void ConvertReduceOperator(const Model& model, const T& src_op, const tensorflow::DataType params_type = GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); + const tensorflow::DataType indices_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*new_op->mutable_attr())["Tidx"].set_type(indices_type); if (src_op.keep_dims) { (*new_op->mutable_attr())["keep_dims"].set_b(true); @@ -1694,43 +1739,43 @@ void ConvertSubOperator(const Model& model, const SubOperator& src_op, void ConvertTensorFlowMinimumOperator(const Model& model, const TensorFlowMinimumOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Minimum"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); + min_op->set_op("Minimum"); + min_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; + *min_op->add_input() = src_op.inputs[0]; + *min_op->add_input() = src_op.inputs[1]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[0]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*min_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMaximumOperator(const Model& model, const TensorFlowMaximumOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Maximum"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); + max_op->set_op("Maximum"); + max_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; + *max_op->add_input() = src_op.inputs[0]; + *max_op->add_input() = src_op.inputs[1]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[0]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*max_op->mutable_attr())["T"].set_type(data_type); } void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Select"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* select_op = tensorflow_graph->add_node(); + select_op->set_op("Select"); + select_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; - *sub_op->add_input() = src_op.inputs[2]; + *select_op->add_input() = src_op.inputs[0]; + *select_op->add_input() = src_op.inputs[1]; + *select_op->add_input() = src_op.inputs[2]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[1]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*select_op->mutable_attr())["T"].set_type(data_type); } void ConvertTileOperator(const Model& model, @@ -1753,11 +1798,14 @@ void ConvertTileOperator(const Model& model, void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* topk_op = tensorflow_graph->add_node(); - topk_op->set_op("TOPKV2"); + topk_op->set_op("TopKV2"); topk_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *topk_op->add_input() = src_op.inputs[0]; *topk_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*topk_op->mutable_attr())["T"].set_type(data_type); (*topk_op->mutable_attr())["sorted"].set_b(true); } @@ -1828,6 +1876,43 @@ void ConvertPowOperator(const Model& model, const PowOperator& src_op, (*pow_op->mutable_attr())["T"].set_type(data_type); } +void ConvertAnyOperator(const Model& model, const AnyOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* any_op = tensorflow_graph->add_node(); + any_op->set_op("Any"); + any_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *any_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*any_op->mutable_attr())["Tidx"].set_type(data_type); + (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims); +} + +void ConvertLogicalAndOperator(const Model& model, + const LogicalAndOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); + logical_op->set_op("LogicalAnd"); + logical_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *logical_op->add_input() = src_op.inputs[i]; + } +} + +void ConvertLogicalNotOperator(const Model& model, + const LogicalNotOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); + logical_op->set_op("LogicalNot"); + logical_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *logical_op->add_input() = src_op.inputs[0]; +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1864,7 +1949,7 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertMulOperator(model, static_cast<const MulOperator&>(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kRelu) { - ConvertReluOperator(static_cast<const ReluOperator&>(src_op), + ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kRelu1) { ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op), @@ -1974,6 +2059,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertReduceOperator(model, static_cast<const TensorFlowProdOperator&>(src_op), tensorflow_graph, "Prod"); + } else if (src_op.type == OperatorType::kReduceMin) { + ConvertReduceOperator(model, + static_cast<const TensorFlowMaxOperator&>(src_op), + tensorflow_graph, "Min"); } else if (src_op.type == OperatorType::kReduceMax) { ConvertReduceOperator(model, static_cast<const TensorFlowMaxOperator&>(src_op), @@ -2060,6 +2149,17 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kPow) { ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow", tensorflow_graph); + } else if (src_op.type == OperatorType::kAny) { + ConvertAnyOperator(model, static_cast<const AnyOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalAnd) { + ConvertLogicalAndOperator(model, + static_cast<const LogicalAndOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalNot) { + ConvertLogicalNotOperator(model, + static_cast<const LogicalNotOperator&>(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } @@ -2138,6 +2238,9 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, const auto& array = *array_pair.second; if (array.buffer) { switch (array.data_type) { + case ArrayDataType::kBool: + ConvertBoolTensorConst(model, array_name, tensorflow_graph); + break; case ArrayDataType::kFloat: ConvertFloatTensorConst(model, array_name, tensorflow_graph); break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 56f48d47de..310a88484c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -40,11 +40,6 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { // Yield until input dims have been resolved. return false; } - if (input_array.shape().dimensions_count() == 0) { - // Input array cannot be 0-D. - // (Unsure if this is TF behavior, but was required to get a test to pass.) - return false; - } const auto& axis_array = model->GetArray(expand_op->inputs[1]); if (!axis_array.has_shape()) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 5cee08fd4c..b7634e28c6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -195,6 +195,7 @@ DECLARE_GRAPH_TRANSFORMATION(Dequantize) DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup) DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights) DECLARE_GRAPH_TRANSFORMATION(ResolveFakeQuantArgsFromVars) +DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes) class PropagateDefaultMinMax : public GraphTransformation { public: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 670bcf64e7..3dda536ef7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -62,6 +62,9 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { case OperatorType::kGreaterEqual: case OperatorType::kEqual: case OperatorType::kNotEqual: + case OperatorType::kAny: + case OperatorType::kLogicalAnd: + case OperatorType::kLogicalNot: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 5e2ba0eca7..62ed5c46e9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -437,6 +437,7 @@ void ProcessTensorFlowReshapeOperator(Model* model, product_non_wildcard_dims *= shape_data[i]; } } + const int input_flat_size = RequiredBufferSizeForShape(input_shape); if (has_wildcard) { CHECK_GE(input_flat_size, product_non_wildcard_dims) @@ -445,6 +446,12 @@ void ProcessTensorFlowReshapeOperator(Model* model, << op->outputs[0] << "\". Are your input shapes correct?"; shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims; } + + if (shape_data.size() == 1 && shape_data[0] == 0) { + // We have reshaped a scalar, so preserve as a scalar. + shape_data.clear(); + } + auto& output_shape = *output_array.mutable_shape(); *output_shape.mutable_dims() = shape_data; CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape)) @@ -522,7 +529,7 @@ void ProcessAddNOperator(Model* model, Operator* op) { bool KeepDims(const Operator& op) { switch (op.type) { - case OperatorType::kMin: // Reduction Min + case OperatorType::kReduceMin: // Reduction Min return static_cast<const TensorFlowMinOperator&>(op).keep_dims; case OperatorType::kReduceMax: // Reduction Max return static_cast<const TensorFlowMaxOperator&>(op).keep_dims; @@ -1036,17 +1043,28 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { return; } + // Yield until the axis has been resolved. + if (!op->axis) { + return; + } + int axis = op->axis.value(); + const auto& input_shape = input_array.shape(); const auto& indices_shape = indices_array.shape(); QCHECK_GE(input_shape.dimensions_count(), 1); op->input_rank = input_shape.dimensions_count(); + QCHECK_LT(axis, op->input_rank); - // Copy the input dimensions to the output except for dimension 0, + // Copy the input dimensions to the output except for the axis dimensions // where the dimension of indices_shape is used. - // TODO(mgubin): if axis != 0 this is not true, change when it's supported. auto output_dims = output_array.mutable_shape()->mutable_dims(); - output_dims->push_back(indices_shape.dims(0)); - for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { + for (int dim = 0; dim < axis; ++dim) { + output_dims->push_back(input_shape.dims(dim)); + } + for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) { + output_dims->push_back(indices_shape.dims(dim)); + } + for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) { output_dims->push_back(input_shape.dims(dim)); } } @@ -1501,6 +1519,65 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { } } +void ProcessAnyOperator(Model* model, AnyOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& reduction_indices_array = model->GetArray(op->inputs[1]); + if (!reduction_indices_array.has_shape()) { + // Yield until reduction indices shape been resolved. + return; + } + if (!reduction_indices_array.buffer) { + // Yield until the reduction indices are constant. + return; + } + CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32) + << "Any reduction input must be int32"; + + int input_rank = input_shape.dimensions_count(); + std::set<int32> true_indices; + const auto& reduction_indices = + reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < reduction_indices.size(); ++i) { + const int32 reduction_index = reduction_indices[i]; + if (reduction_index < -input_rank || reduction_index >= input_rank) { + CHECK(false) << "Invalid reduction dimension " << reduction_index + << " for input with " << input_rank << " dimensions"; + } + int32 wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; + } + true_indices.insert(wrapped_index); + } + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->clear(); + for (int i = 0; i < input_rank; ++i) { + if (true_indices.count(i) > 0) { + if (op->keep_dims) { + mutable_dims->emplace_back(1); + } + } else { + mutable_dims->emplace_back(input_shape.dims(i)); + } + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1539,6 +1616,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kFloor: case OperatorType::kExp: case OperatorType::kSin: + case OperatorType::kLogicalAnd: + case OperatorType::kLogicalNot: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: @@ -1607,7 +1686,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kL2Pool: ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op)); break; - case OperatorType::kMin: // Reduction Min + case OperatorType::kReduceMin: // Reduction Min case OperatorType::kReduceMax: // Reduction Max case OperatorType::kSum: case OperatorType::kReduceProd: @@ -1732,6 +1811,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTile: ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op)); break; + case OperatorType::kAny: + ProcessAnyOperator(model, static_cast<AnyOperator*>(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc index 404f27e067..5295eeccec 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -59,6 +59,15 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, if (CountOpsWithInput(model, op.outputs[0]) == 1) { const auto* next_op = GetOpWithInput(model, op.outputs[0]); if (next_op->type == OperatorType::kReshape) { + if (!IsDiscardableArray(model, next_op->outputs[0])) { + // If the |next_op| output is used as a model output we need to preserve + // its shape. + transformation->AddMessageF( + "%s cannot be merged into following reshape %s as it is " + "non-discardable and must keep the specified shape", + LogName(op), LogName(*next_op)); + return false; + } transformation->AddMessageF( "%s is trivial because its output is only consumed by another " "Reshape op %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc index debe298a5a..36d7dad0ce 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc @@ -69,7 +69,7 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { } const auto* op = static_cast<const GatherOperator*>(base_op); - CHECK_EQ(op->inputs.size(), 2); + CHECK_GE(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { @@ -81,10 +81,14 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { return false; } - // Only handling axis=0 for now. - if (op->axis != 0) { + if (!op->axis) { + // Yield until axis has been set by ResolveGatherAttributes. + return false; + } + if (op->axis.value() != 0) { + // Only handling axis=0 for now. AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op), - op->axis); + op->axis.value()); return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index 51099cf74a..fe3882c28d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -57,7 +57,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { case OperatorType::kSqrt: case OperatorType::kSquare: case OperatorType::kSum: - case OperatorType::kMin: // Reduction Min + case OperatorType::kReduceMin: // Reduction Min case OperatorType::kReduceMax: // Reduction Max case OperatorType::kReshape: case OperatorType::kRelu6: @@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = sum; } - } else if (unary_op->type == OperatorType::kMin) { + } else if (unary_op->type == OperatorType::kReduceMin) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc new file mode 100644 index 0000000000..ce825c91af --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) { + auto* gather_op = model->operators[op_index].get(); + if (gather_op->type != OperatorType::kGather) return false; + auto* op = static_cast<GatherOperator*>(gather_op); + + if (op->axis) { + // Attributes already resolved + return false; + } + if (op->inputs.size() != 3) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + + const auto& indices_array = model->GetArray(op->inputs[2]); + if (!indices_array.has_shape()) return false; + const auto& axis_data = indices_array.GetBuffer<ArrayDataType::kInt32>().data; + CHECK_EQ(axis_data.size(), 1) + << "Multidimensional gather not supported on " << LogName(*op); + op->axis = {axis_data[0]}; + + // Drop the axis array as we no longer need it. + DeleteArrayIfUsedOnce(op->inputs[2], model); + op->inputs.resize(2); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc index 5f8a06ba92..7d456af2fb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc @@ -48,6 +48,8 @@ bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) { return ResolveAttributes(model, static_cast<TensorFlowSumOperator*>(op)); case OperatorType::kReduceProd: return ResolveAttributes(model, static_cast<TensorFlowProdOperator*>(op)); + case OperatorType::kReduceMin: + return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op)); case OperatorType::kReduceMax: return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op)); default: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc index 2c7046c8c7..69bad2fa89 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -64,7 +64,14 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { const string& tmp_array_name = AvailableArrayName(*model, op->outputs[0] + "_unfused"); CHECK(!model->HasArray(tmp_array_name)); - model->GetOrCreateArray(tmp_array_name); + + const auto& output_array = model->GetArray(op->outputs[0]); + auto& tmp_array = model->GetOrCreateArray(tmp_array_name); + if (output_array.quantization_params) { + tmp_array.GetOrCreateQuantizationParams() = + output_array.GetQuantizationParams(); + } + ac_op->inputs = {tmp_array_name}; op->outputs = {tmp_array_name}; return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc index cbea39bcc0..dd9e26e68b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc @@ -187,6 +187,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted/perm")); gather_params_permute_op->outputs.push_back( AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted")); + gather_params_permute_op->axis = {0}; op_it = model->operators.emplace(op_it, gather_params_permute_op) + 1; model->GetOrCreateArray(gather_params_permute_op->outputs[0]); const auto& partition_array = model->GetArray(gather_ops[0]->inputs[0]); @@ -212,6 +213,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { mod_op->inputs[0]}; merged_gather_op->outputs = {stitch_op->outputs[0]}; merged_gather_op->input_rank = partition_array.shape().dimensions_count(); + merged_gather_op->axis = {0}; model->operators.emplace(op_it, merged_gather_op); AddMessageF( diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 576eb71534..8bb797fe0f 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1042,22 +1042,6 @@ tensorflow::Status ConvertSimpleOperator( return ConvertSimpleOperator<Op>(node, tf_import_flags, model); } -tensorflow::Status ConvertMinOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Min"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowMinOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1197,8 +1181,17 @@ tensorflow::Status ConvertGatherOperator( auto* op = new GatherOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); - // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we - // should read it an pass it on to the TF Lite Interpreter. + if (node.input_size() >= 3) { + // GatherV2 form where we are provided an axis. It may be either a constant + // or runtime defined value, so we just wire up the array and let + // ResolveGatherAttributes take care of it later on. + const auto axis_data_type = GetDataTypeAttr(node, "Taxis"); + CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64); + op->inputs.push_back(node.input(2)); + } else { + // Gather form that assumes axis=0. + op->axis = {0}; + } op->outputs.push_back(node.name()); model->operators.emplace_back(op); return tensorflow::Status::OK(); @@ -1585,6 +1578,24 @@ tensorflow::Status ConvertShapeOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertAnyOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Any"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + const auto idx_type = + HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; + CHECK(idx_type == DT_INT32); + auto op = absl::make_unique<AnyOperator>(); + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + op->keep_dims = + HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false; + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); +} + void StripCaretFromArrayNames(Model* model) { for (auto& op : model->operators) { for (auto& input : op->inputs) { @@ -1820,6 +1831,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Add", ConvertSimpleOperator<AddOperator, 2>}, {"AddN", ConvertSimpleOperator<AddNOperator>}, {"All", ConvertSimpleOperator<TensorFlowAllOperator>}, + {"Any", ConvertAnyOperator}, {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>}, {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>}, {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>}, @@ -1862,15 +1874,16 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>}, {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>}, {"Log", ConvertSimpleOperator<LogOperator, 1>}, - {"Log", ConvertSimpleOperator<LogOperator, 1>}, {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>}, + {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>}, + {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>}, {"MatMul", ConvertMatMulOperator}, {"Max", ConvertReduceOperator<TensorFlowMaxOperator>}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>}, {"Mean", ConvertReduceOperator<MeanOperator>}, {"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>}, - {"Min", ConvertMinOperator}, + {"Min", ConvertReduceOperator<TensorFlowMinOperator>}, {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>}, {"Mul", ConvertSimpleOperator<MulOperator, 2>}, {"Neg", ConvertSimpleOperator<NegOperator, 1>}, diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 8fff68cf47..6fe194516d 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -23,6 +23,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/types/optional.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" #include "tensorflow/contrib/lite/toco/toco_port.h" @@ -109,7 +110,7 @@ enum class OperatorType : uint8 { kLessEqual, kReduceMax, // Reduction Max kMaximum, // Element-wise Maximum - kMin, // Reduction Min + kReduceMin, // Reduction Min kMinimum, // Element-wise Minimum kMatMul, kMerge, @@ -142,6 +143,9 @@ enum class OperatorType : uint8 { kNotEqual, kPow, kArgMin, + kAny, + kLogicalAnd, + kLogicalNot, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1415,16 +1419,15 @@ struct TensorFlowMaxOperator : Operator { bool keep_dims = false; }; -// Global min reduction: computes the min of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Min reduction: computes the min of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Min --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Min struct TensorFlowMinOperator : Operator { - TensorFlowMinOperator() : Operator(OperatorType::kMin) {} + TensorFlowMinOperator() : Operator(OperatorType::kReduceMin) {} + std::vector<int> axis; bool keep_dims = false; }; @@ -1525,11 +1528,15 @@ struct FloorOperator : Operator { // Inputs: // inputs[0]: required: the params array // inputs[1]: required: the indices to gather +// inputs[2]: optional: axis // // TensorFlow equivalent: Gather struct GatherOperator : Operator { GatherOperator() : Operator(OperatorType::kGather) {} - int axis = 0; + // Axis is populated explicitly or implicitly from the axis input by + // ResolveGatherAttributes. An empty axis indicates that the axis has not yet + // be resolved. + absl::optional<int> axis; int input_rank = 0; }; @@ -1685,6 +1692,39 @@ struct PowOperator : Operator { PowOperator() : Operator(OperatorType::kPow) {} }; +// Any operator: +// +// Inputs: +// Inputs[0]: required: A boolean input tensor. +// Inputs[1]: required: reduction_indices. +// +// TensorFlow equivalent: tf.reduce_any. +struct AnyOperator : Operator { + AnyOperator() : Operator(OperatorType::kAny) {} + bool keep_dims = false; +}; + +// LogicalAnd operator: +// +// Inputs: +// Inputs[0]: required: A boolean tensor. +// Inputs[1]: required: A boolean tensor. +// +// TensorFlow equivalent: tf.logical_and. +struct LogicalAndOperator : Operator { + LogicalAndOperator() : Operator(OperatorType::kLogicalAnd) {} +}; + +// LogicalNot operator: +// +// Inputs: +// Inputs[0]: required: A boolean tensor. +// +// TensorFlow equivalent: tf.logical_not. +struct LogicalNotOperator : Operator { + LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 68d13586f1..1a1c4b8944 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -370,12 +370,13 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions, flatbuffers::Offset<TfLiteOptions> WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateGatherOptions(*builder, op.axis); + int axis = op.axis ? op.axis.value() : 0; + return ::tflite::CreateGatherOptions(*builder, axis); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->axis = options.axis(); + op->axis = {options.axis()}; } int GetVersion(const Operator& op) const override { return 1; } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index d8964ebc13..aa7f6996eb 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -117,6 +117,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); transformations->Add(new UnpartitionEmbeddingLookup); + transformations->Add(new ResolveGatherAttributes); } bool SupportsQuantization(FileFormat format) { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 4305727c8c..52f8df45a2 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -351,10 +351,10 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(LessEqual) HANDLE_OPERATORTYPENAME_CASE(MatMul) HANDLE_OPERATORTYPENAME_CASE(ReduceMax) // Reduction Max - HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum + HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum HANDLE_OPERATORTYPENAME_CASE(Merge) - HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min - HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum + HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min + HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum HANDLE_OPERATORTYPENAME_CASE(Neg) HANDLE_OPERATORTYPENAME_CASE(Pack) HANDLE_OPERATORTYPENAME_CASE(Pad) @@ -399,6 +399,9 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Equal) HANDLE_OPERATORTYPENAME_CASE(NotEqual) HANDLE_OPERATORTYPENAME_CASE(Pow) + HANDLE_OPERATORTYPENAME_CASE(Any) + HANDLE_OPERATORTYPENAME_CASE(LogicalAnd) + HANDLE_OPERATORTYPENAME_CASE(LogicalNot) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -940,8 +943,12 @@ void CheckEachArray(const Model& model) { // shape. CHECK(array->has_shape()); // Constant buffer should has a valid shape. - for (int d : array->shape().dims()) { - CHECK_GE(d, 1); + bool is_scalar = + array->shape().dimensions_count() == 1 && array->shape().dims(0) == 0; + if (!is_scalar) { + for (int d : array->shape().dims()) { + CHECK_GE(d, 1); + } } // The shape flat-size should agree with the buffer length. CHECK_EQ(array->buffer->Length(), diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 6e7423f85e..ecf2e120df 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -229,6 +229,8 @@ tensorflow/core/kernels/cast_op_impl_int32.cc tensorflow/core/kernels/cast_op_impl_int64.cc tensorflow/core/kernels/cast_op_impl_int8.cc tensorflow/core/kernels/cast_op_impl_uint16.cc +tensorflow/core/kernels/cast_op_impl_uint32.cc +tensorflow/core/kernels/cast_op_impl_uint64.cc tensorflow/core/kernels/cast_op_impl_uint8.cc tensorflow/core/kernels/boosted_trees/prediction_ops.cc tensorflow/core/kernels/boosted_trees/resource_ops.cc diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc index b1cb89391c..99fecf9651 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc @@ -445,7 +445,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { se::Stream* comm_stream = nccl_stream->stream.get(); ScopedActivateExecutorContext scoped_context(nccl_stream->executor); const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>( - comm_stream->implementation()->CudaStreamMemberHack()); + comm_stream->implementation()->GpuStreamMemberHack()); while (true) { // Find collective to run. diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 136856c015..164f3e58e6 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -223,7 +223,6 @@ tf_kernel_library( ":model_ops_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", ], alwayslink = 1, ) @@ -319,7 +318,6 @@ tf_kernel_library( ":stats_ops_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", ], alwayslink = 1, ) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 089b03dcb5..68c78e8301 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -831,9 +831,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { // The allocator is used to build the engine. The build and the built engine // will be destroyed after we get the serialized engine string, so it's fine // to use unique_ptr here. - // TODO(aaroey): nvinfer1::IGpuAllocator doesn't have a virtual destructor - // and destructing the unique_ptr will result in segfault, fix it. - std::unique_ptr<TRTDeviceAllocator> alloc; + std::unique_ptr<TRTBaseAllocator> alloc; auto device_alloc = GetDeviceAndAllocator(params, engine); int cuda_device_id = 0; if (device_alloc.first >= 0) { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 988b35f74f..2de7973750 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -65,7 +65,7 @@ class IncPluginTRT : public OpKernel { reinterpret_cast<const cudaStream_t*>(context->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); IncrementKernel(input_tensor.flat<float>().data(), inc_, output_tensor->flat<float>().data(), input_shape.num_elements(), *stream); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 04d072f5d9..54009179a8 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -230,7 +230,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); @@ -391,7 +391,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); // TODO(jie): trt enqueue does not return error auto& trt_execution_context_ptr = engine_ctx_pair.second; diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index 6fe318be6a..9265250605 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -81,7 +81,7 @@ class TRTEngineOp : public AsyncOpKernel { std::vector<string> output_nodes_; // keep device allocator for TRT. - std::unique_ptr<TRTDeviceAllocator> allocator_; + std::unique_ptr<TRTBaseAllocator> allocator_; // serialized protobuf segment or trt engine depending on static_engine_ flag. string serialized_segment_; diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc index 9f115990c3..81d7330b49 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc @@ -37,8 +37,22 @@ void TRTCudaAllocator::free(void* memory) { cudaFree(memory); } void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment, uint32_t flags) { + // WAR for allocator alignment requirement. Certain cuda API calls require GPU + // memory with alignemtn to cudaDeviceProp::textureAlignment. + // See issue #20856 + alignment = 512; assert((alignment & (alignment - 1)) == 0); // zero or a power of 2. - void* mem = allocator_->AllocateRaw(alignment, size); + size_t total_size = size + alignment; + void* mem = allocator_->AllocateRaw(alignment, total_size); + if (!mem) { + return nullptr; + } + + void* alloc_mem = mem; + CHECK(std::align(alignment, size, mem, total_size)); + if (mem != alloc_mem) { + CHECK(mem_map_.insert({mem, alloc_mem}).second); + } VLOG(2) << "Allocated " << size << " bytes with alignment " << alignment << " @ " << mem; return mem; @@ -51,7 +65,15 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator) void TRTDeviceAllocator::free(void* memory) { VLOG(2) << "Deallocating @ " << memory; - allocator_->DeallocateRaw(memory); + // allocated memory adjusted for alignment, restore the original pointer + if (memory) { + auto alloc_mem = mem_map_.find(memory); + if (alloc_mem != mem_map_.end()) { + memory = alloc_mem->second; + mem_map_.erase(alloc_mem->first); + } + allocator_->DeallocateRaw(memory); + } } } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h index 97ac82ca5d..b8825b108d 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h @@ -37,7 +37,14 @@ class IGpuAllocator { namespace tensorflow { namespace tensorrt { -class TRTCudaAllocator : public nvinfer1::IGpuAllocator { +class TRTBaseAllocator : public nvinfer1::IGpuAllocator { + // Base allocator class so we can have a virtual destructor; + public: + // python wrapper seems to be not happy with an pure virtual destructor; + virtual ~TRTBaseAllocator() = default; +}; + +class TRTCudaAllocator : public TRTBaseAllocator { // Allocator implementation that is using cuda allocator instead of device // allocator in case we can't get device allocator from TF. public: @@ -47,7 +54,7 @@ class TRTCudaAllocator : public nvinfer1::IGpuAllocator { void free(void* memory) override; }; -class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { +class TRTDeviceAllocator : public TRTBaseAllocator { // Allocator implementation wrapping TF device allocators. public: TRTDeviceAllocator(tensorflow::Allocator* allocator); @@ -62,6 +69,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { private: tensorflow::Allocator* allocator_; + + // supporting alignment from allocation request requires a map to free; + std::unordered_map<void*, void*> mem_map_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index b7d5ffd674..d7d56cb95e 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -64,7 +64,7 @@ class TRTCalibrationResource : public tensorflow::ResourceBase { std::unique_ptr<TRTInt8Calibrator> calibrator_; TrtUniquePtrType<nvinfer1::IBuilder> builder_; TrtUniquePtrType<nvinfer1::ICudaEngine> engine_; - std::unique_ptr<nvinfer1::IGpuAllocator> allocator_; + std::unique_ptr<TRTBaseAllocator> allocator_; tensorflow::tensorrt::Logger logger_; // TODO(sami): Use threadpool threads! std::unique_ptr<std::thread> thr_; diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py index 7c3ef498c9..035b112254 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py @@ -186,8 +186,8 @@ class TfTrtIntegrationTest(test_util.TensorFlowTestCase): # Defaults to 2 runs to verify result across multiple runs is same. for _ in range(num_runs): new_val = sess.run(out, {inp: input_data}) - self.assertEquals(TEST_GRAPHS[graph_key].expected_output_dims, - new_val.shape) + self.assertEqual(TEST_GRAPHS[graph_key].expected_output_dims, + new_val.shape) if val is not None: self.assertAllEqual(new_val, val) val = new_val @@ -220,19 +220,19 @@ class TfTrtIntegrationTest(test_util.TensorFlowTestCase): for n in gdef.node: if n.op == "TRTEngineOp": num_engines += 1 - self.assertNotEqual("", n.attr["serialized_segment"].s) - self.assertNotEqual("", n.attr["segment_funcdef_name"].s) - self.assertEquals(n.attr["precision_mode"].s, precision_mode) - self.assertEquals(n.attr["static_engine"].b, not dynamic_engine) + self.assertNotEqual(to_bytes(""), n.attr["serialized_segment"].s) + self.assertNotEqual(to_bytes(""), n.attr["segment_funcdef_name"].s) + self.assertEqual(n.attr["precision_mode"].s, to_bytes(precision_mode)) + self.assertEqual(n.attr["static_engine"].b, not dynamic_engine) if precision_mode == MODE_INT8 and is_calibrated: - self.assertNotEqual("", n.attr["calibration_data"].s) + self.assertNotEqual(to_bytes(""), n.attr["calibration_data"].s) else: - self.assertEquals("", n.attr["calibration_data"].s) + self.assertEqual(to_bytes(""), n.attr["calibration_data"].s) if precision_mode is None: - self.assertEquals(num_engines, 0) + self.assertEqual(num_engines, 0) else: - self.assertEquals(num_engines, - TEST_GRAPHS[graph_key].num_expected_engines) + self.assertEqual(num_engines, + TEST_GRAPHS[graph_key].num_expected_engines) def _RunTest(self, graph_key, use_optimizer, precision_mode, dynamic_infer_engine, dynamic_calib_engine): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 718ea630a8..78b79b111e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -701,8 +701,6 @@ def generate_per_core_enqueue_ops_fn_for_host( infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0])) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_configuration_from_sharded_input_tensors( - per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) @@ -837,8 +835,6 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0])) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_configuration_from_sharded_input_tensors( - per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) @@ -867,7 +863,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, def tpu_ordinal_function_impl(replica_id): if ctx.device_assignment: - return ctx.device_assignment.tpu_ordinal(replica_id=replica_id) + return ctx.device_assignment.tpu_ordinal(replica=replica_id) else: return replica_id % num_replicas_per_host diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt index b07ee9fda9..17b79ee30c 100644 --- a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt @@ -51,7 +51,7 @@ For example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this: ```python - ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) indices = tf.constant([[4], [3], [1] ,[7]]) updates = tf.constant([9, 10, 11, 12]) update = tf.scatter_nd_update(ref, indices, updates) diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 3cb51b0dbc..7110ffd40c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/common_runtime/visitable_allocator.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" @@ -856,7 +857,7 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context, static_cast<ConcretePerOpGpuDevice*>(device); DCHECK(concrete_device); const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>( - streams_[stream_id]->compute->implementation()->CudaStreamMemberHack()); + streams_[stream_id]->compute->implementation()->GpuStreamMemberHack()); concrete_device->Reinitialize(context, cuda_stream, tf_gpu_id_, allocator, scratch_[stream_id]); } diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc index 4d83b25ce6..447338e7bd 100644 --- a/tensorflow/core/common_runtime/process_state.cc +++ b/tensorflow/core/common_runtime/process_state.cc @@ -71,7 +71,7 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { return MemDesc(); } -Allocator* ProcessState::GetCPUAllocator(int numa_node) { +VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) { CHECK_GE(numa_node, 0); if (!numa_enabled_) numa_node = 0; mutex_lock lock(mu_); diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h index 0f4ae230bb..2892677333 100644 --- a/tensorflow/core/common_runtime/process_state.h +++ b/tensorflow/core/common_runtime/process_state.h @@ -65,7 +65,7 @@ class ProcessState { // Returns the one CPUAllocator used for the given numa_node. // TEMPORARY: ignores numa_node. - Allocator* GetCPUAllocator(int numa_node); + VisitableAllocator* GetCPUAllocator(int numa_node); typedef std::unordered_map<const void*, MemDesc> MDMap; @@ -87,7 +87,7 @@ class ProcessState { mutex mu_; - std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_); + std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_); virtual ~ProcessState(); diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc index 4652fbe406..9b4200e0b4 100644 --- a/tensorflow/core/graph/algorithm.cc +++ b/tensorflow/core/graph/algorithm.cc @@ -25,7 +25,8 @@ namespace tensorflow { void DFS(const Graph& g, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator) { + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { // Stack of work to do. struct Work { Node* node; @@ -52,7 +53,6 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter, // Arrange to call leave(n) when all done with descendants. if (leave) stack.push_back(Work{n, true}); - gtl::iterator_range<NeighborIter> nodes = n->out_nodes(); auto add_work = [&visited, &stack](Node* out) { if (!visited[out->id()]) { // Note; we must not mark as visited until we actually process it. @@ -62,16 +62,20 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter, if (stable_comparator) { std::vector<Node*> nodes_sorted; - for (Node* out : nodes) { - nodes_sorted.emplace_back(out); + for (const Edge* out_edge : n->out_edges()) { + if (!edge_filter || edge_filter(*out_edge)) { + nodes_sorted.emplace_back(out_edge->dst()); + } } std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); for (Node* out : nodes_sorted) { add_work(out); } } else { - for (Node* out : nodes) { - add_work(out); + for (const Edge* out_edge : n->out_edges()) { + if (!edge_filter || edge_filter(*out_edge)) { + add_work(out_edge->dst()); + } } } } @@ -118,8 +122,6 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, // Arrange to call leave(n) when all done with descendants. if (leave) stack.push_back(Work{n, true}); - gtl::iterator_range<NeighborIter> nodes = n->in_nodes(); - auto add_work = [&visited, &stack](T out) { if (!visited[out->id()]) { // Note; we must not mark as visited until we actually process it. @@ -129,16 +131,16 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, if (stable_comparator) { std::vector<T> nodes_sorted; - for (T in : nodes) { - nodes_sorted.emplace_back(in); + for (const Edge* in_edge : n->in_edges()) { + nodes_sorted.emplace_back(in_edge->src()); } std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); for (T in : nodes_sorted) { add_work(in); } } else { - for (T in : nodes) { - add_work(in); + for (const Edge* in_edge : n->in_edges()) { + add_work(in_edge->src()); } } } @@ -161,14 +163,17 @@ void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, } void GetPostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator) { + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { order->clear(); - DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator); + DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator, + edge_filter); } void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator) { - GetPostOrder(g, order, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + GetPostOrder(g, order, stable_comparator, edge_filter); std::reverse(order->begin(), order->end()); } diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h index ac4a099013..5bbbc6f6dc 100644 --- a/tensorflow/core/graph/algorithm.h +++ b/tensorflow/core/graph/algorithm.h @@ -28,6 +28,8 @@ namespace tensorflow { // Comparator for two nodes. This is used in order to get a stable ording. using NodeComparator = std::function<bool(const Node*, const Node*)>; +using EdgeFilter = std::function<bool(const Edge&)>; + // Compares two node based on their ids. struct NodeComparatorID { bool operator()(const Node* n1, const Node* n2) const { @@ -47,9 +49,11 @@ struct NodeComparatorName { // If leave is not empty, calls leave(n) after visiting all children of n. // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. extern void DFS(const Graph& g, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Perform a reverse depth-first-search on g starting at the sink node. // If enter is not empty, calls enter(n) before visiting any parents of n. @@ -83,15 +87,21 @@ extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. // +// If edge_filter is set then ignores edges for which edge_filter returns false. +// // REQUIRES: order is not NULL. void GetPostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Stores in *order the reverse post-order numbering of all nodes // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// +// If edge_filter is set then ignores edges for which edge_filter returns false. void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Prune nodes in "g" that are not in some path from the source node // to any node in 'nodes'. Returns true if changes were made to the graph. diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc index f67d5a2fd2..60a3e66aa1 100644 --- a/tensorflow/core/graph/algorithm_test.cc +++ b/tensorflow/core/graph/algorithm_test.cc @@ -36,6 +36,11 @@ namespace { REGISTER_OP("TestParams").Output("o: float"); REGISTER_OP("TestInput").Output("a: float").Output("b: float"); REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_OP("TestUnary").Input("a: float").Output("o: float"); +REGISTER_OP("TestBinary") + .Input("a: float") + .Input("b: float") + .Output("o: float"); // Compares that the order of nodes in 'inputs' respects the // pair orders described in 'ordered_pairs'. @@ -148,5 +153,52 @@ TEST(AlgorithmTest, ReversePostOrderStable) { EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error)); } } + +TEST(AlgorithmTest, PostOrderWithEdgeFilter) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + string error; + Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0")); + Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1")); + Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2")); + Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1); + + std::vector<Node*> post_order; + auto edge_filter = [&](const Edge& e) { + return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id()); + }; + + std::vector<Node*> expected_post_order = { + g.sink_node(), g.FindNodeId(n3->id()), g.FindNodeId(n2->id()), + g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()}; + + std::vector<Node*> expected_reverse_post_order = expected_post_order; + std::reverse(expected_reverse_post_order.begin(), + expected_reverse_post_order.end()); + + GetPostOrder(g, &post_order, /*stable_comparator=*/{}, + /*edge_filter=*/edge_filter); + + ASSERT_EQ(expected_post_order.size(), post_order.size()); + for (int i = 0; i < post_order.size(); i++) { + CHECK_EQ(post_order[i], expected_post_order[i]) + << post_order[i]->name() << " vs. " << expected_post_order[i]->name(); + } + + std::vector<Node*> reverse_post_order; + GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{}, + /*edge_filter=*/edge_filter); + + ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size()); + for (int i = 0; i < reverse_post_order.size(); i++) { + CHECK_EQ(reverse_post_order[i], expected_reverse_post_order[i]) + << reverse_post_order[i]->name() << " vs. " + << expected_reverse_post_order[i]->name(); + } +} } // namespace } // namespace tensorflow diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index add26f3b71..8c73f8f712 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1042,6 +1042,14 @@ Status GraphConstructor::Convert() { } if (processed < node_defs_.size()) { + LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed) + << " NODES IN A CYCLE"; + for (int64 i = 0; i < node_defs_.size(); i++) { + if (pending_count_[i] != 0) { + LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i]) + << "WITH PENDING COUNT = " << pending_count_[i]; + } + } return errors::InvalidArgument(node_defs_.size() - processed, " nodes in a cycle"); } diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index 8d8c6084ec..6d84283e68 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -29,6 +29,14 @@ void Cluster::AllowSoftPlacement(bool soft_placement_state) { options_.config.set_allow_soft_placement(soft_placement_state); } +void Cluster::SetNumInterOpThreads(int num_threads) { + for (int i = 0; i < options_.config.session_inter_op_thread_pool_size(); + ++i) { + options_.config.mutable_session_inter_op_thread_pool(i)->set_num_threads( + num_threads); + } +} + void Cluster::SetNumWarmupSteps(int num_steps) { options_.config.mutable_graph_options()->set_build_cost_model_after( num_steps); diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h index 06db36b3aa..e94fb900c0 100644 --- a/tensorflow/core/grappler/clusters/cluster.h +++ b/tensorflow/core/grappler/clusters/cluster.h @@ -65,6 +65,9 @@ class Cluster { // with reftype input(s) which are from CPU. void AllowSoftPlacement(bool soft_placement_state); + // Update the number of inter-op threads for each per-session threadpool + void SetNumInterOpThreads(int num_threads); + // Set the number of steps required to warmup TensorFlow. Must be called // before Provision(). void SetNumWarmupSteps(int num_steps); diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 3cb9d4d61c..c8946c499c 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -48,10 +48,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:grappler_item_builder", "//tensorflow/core/grappler:utils", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/optimizers:meta_optimizer", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index b5b46ccafe..ea5f450009 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -16,11 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/graph_view.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/grappler_item_builder.h" -#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 1c842150fd..99e5e3cfca 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4855,6 +4855,8 @@ filegroup( "cast_op_impl_int64.cc", "cast_op_impl_int8.cc", "cast_op_impl_uint16.cc", + "cast_op_impl_uint32.cc", + "cast_op_impl_uint64.cc", "cast_op_impl_uint8.cc", "concat_lib.h", "concat_lib_cpu.cc", diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 626db9131a..e6e388b3d1 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -41,8 +41,10 @@ typedef Eigen::SyclDevice SYCLDevice; #define CURRY_TYPES2(FN, arg0) \ FN(arg0, bool); \ FN(arg0, uint8); \ - FN(arg0, int8); \ FN(arg0, uint16); \ + FN(arg0, uint32); \ + FN(arg0, uint64); \ + FN(arg0, int8); \ FN(arg0, int16); \ FN(arg0, int32); \ FN(arg0, int64); \ @@ -86,10 +88,14 @@ Status CpuCastOp::Prepare() { work_ = GetCpuCastFromBool(dst_dtype_); } else if (src_dtype_ == DT_UINT8) { work_ = GetCpuCastFromUint8(dst_dtype_); - } else if (src_dtype_ == DT_INT8) { - work_ = GetCpuCastFromInt8(dst_dtype_); } else if (src_dtype_ == DT_UINT16) { work_ = GetCpuCastFromUint16(dst_dtype_); + } else if (src_dtype_ == DT_UINT32) { + work_ = GetCpuCastFromUint32(dst_dtype_); + } else if (src_dtype_ == DT_UINT64) { + work_ = GetCpuCastFromUint64(dst_dtype_); + } else if (src_dtype_ == DT_INT8) { + work_ = GetCpuCastFromInt8(dst_dtype_); } else if (src_dtype_ == DT_INT16) { work_ = GetCpuCastFromInt16(dst_dtype_); } else if (src_dtype_ == DT_INT32) { @@ -135,10 +141,14 @@ class GpuCastOp : public CastOpBase { work_ = GetGpuCastFromBool(dst_dtype_); } else if (src_dtype_ == DT_UINT8) { work_ = GetGpuCastFromUint8(dst_dtype_); - } else if (src_dtype_ == DT_INT8) { - work_ = GetGpuCastFromInt8(dst_dtype_); } else if (src_dtype_ == DT_UINT16) { work_ = GetGpuCastFromUint16(dst_dtype_); + } else if (src_dtype_ == DT_UINT32) { + work_ = GetGpuCastFromUint32(dst_dtype_); + } else if (src_dtype_ == DT_UINT64) { + work_ = GetGpuCastFromUint64(dst_dtype_); + } else if (src_dtype_ == DT_INT8) { + work_ = GetGpuCastFromInt8(dst_dtype_); } else if (src_dtype_ == DT_INT16) { work_ = GetGpuCastFromInt16(dst_dtype_); } else if (src_dtype_ == DT_INT32) { @@ -178,8 +188,10 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp); CURRY_TYPES2(REGISTER_CAST_GPU, bool); CURRY_TYPES2(REGISTER_CAST_GPU, uint8); -CURRY_TYPES2(REGISTER_CAST_GPU, int8); CURRY_TYPES2(REGISTER_CAST_GPU, uint16); +CURRY_TYPES2(REGISTER_CAST_GPU, uint32); +CURRY_TYPES2(REGISTER_CAST_GPU, uint64); +CURRY_TYPES2(REGISTER_CAST_GPU, int8); CURRY_TYPES2(REGISTER_CAST_GPU, int16); CURRY_TYPES2(REGISTER_CAST_GPU, int32); CURRY_TYPES2(REGISTER_CAST_GPU, int64); diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc index 9c9e9e7658..607e7f5efd 100644 --- a/tensorflow/core/kernels/cast_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -37,8 +37,10 @@ struct CastFunctor<GPUDevice, O, I> { #define DEFINE_ALL_FROM(in_type) \ DEFINE(in_type, bool); \ DEFINE(in_type, uint8); \ - DEFINE(in_type, int8); \ DEFINE(in_type, uint16); \ + DEFINE(in_type, uint32); \ + DEFINE(in_type, uint64); \ + DEFINE(in_type, int8); \ DEFINE(in_type, int16); \ DEFINE(in_type, int32); \ DEFINE(in_type, int64); \ @@ -50,8 +52,10 @@ struct CastFunctor<GPUDevice, O, I> { DEFINE_ALL_FROM(bool); DEFINE_ALL_FROM(uint8); -DEFINE_ALL_FROM(int8); DEFINE_ALL_FROM(uint16); +DEFINE_ALL_FROM(uint32); +DEFINE_ALL_FROM(uint64); +DEFINE_ALL_FROM(int8); DEFINE_ALL_FROM(int16); DEFINE_ALL_FROM(int32); DEFINE_ALL_FROM(int64); diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h index 382e5440e1..fe821b25df 100644 --- a/tensorflow/core/kernels/cast_op_impl.h +++ b/tensorflow/core/kernels/cast_op_impl.h @@ -48,8 +48,10 @@ struct CastFunctor<Eigen::SyclDevice, O, I> { #define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ FN(arg0, arg1, bool); \ FN(arg0, arg1, uint8); \ - FN(arg0, arg1, int8); \ FN(arg0, arg1, uint16); \ + FN(arg0, arg1, uint32); \ + FN(arg0, arg1, uint64); \ + FN(arg0, arg1, int8); \ FN(arg0, arg1, int16); \ FN(arg0, arg1, int32); \ FN(arg0, arg1, int64); \ @@ -82,10 +84,16 @@ std::function<void(OpKernelContext*, const Tensor&, Tensor*)> GetCpuCastFromUint8(DataType dst_dtype); std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt8(DataType dst_dtype); +GetCpuCastFromUint16(DataType dst_dtype); std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint16(DataType dst_dtype); +GetCpuCastFromUint32(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromUint64(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromInt8(DataType dst_dtype); std::function<void(OpKernelContext*, const Tensor&, Tensor*)> GetCpuCastFromInt16(DataType dst_dtype); @@ -123,10 +131,16 @@ std::function<void(OpKernelContext*, const Tensor&, Tensor*)> GetGpuCastFromUint8(DataType dst_dtype); std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt8(DataType dst_dtype); +GetGpuCastFromUint16(DataType dst_dtype); std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint16(DataType dst_dtype); +GetGpuCastFromUint32(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromUint64(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromInt8(DataType dst_dtype); std::function<void(OpKernelContext*, const Tensor&, Tensor*)> GetGpuCastFromInt16(DataType dst_dtype); @@ -168,6 +182,12 @@ std::function<void(OpKernelContext*, const Tensor&, Tensor*)> GetSyclCastFromUint16(DataType dst_dtype); std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromUint32(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromUint64(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> GetSyclCastFromInt16(DataType dst_dtype); std::function<void(OpKernelContext*, const Tensor&, Tensor*)> diff --git a/tensorflow/core/kernels/cast_op_impl_uint32.cc b/tensorflow/core/kernels/cast_op_impl_uint32.cc new file mode 100644 index 0000000000..d1a854d98b --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_uint32.cc @@ -0,0 +1,46 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromUint32(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, uint32); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromUint32(DataType dst_dtype) { + CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint32); + return nullptr; +} +#endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromUint32(DataType dst_dtype) { + CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32); + return nullptr; +} +#endif // TENSORFLOW_USE_SYCL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_uint64.cc b/tensorflow/core/kernels/cast_op_impl_uint64.cc new file mode 100644 index 0000000000..604e0424fc --- /dev/null +++ b/tensorflow/core/kernels/cast_op_impl_uint64.cc @@ -0,0 +1,46 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/cast_op_impl.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetCpuCastFromUint64(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, uint64); + return nullptr; +} + +#if GOOGLE_CUDA +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetGpuCastFromUint64(DataType dst_dtype) { + CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint64); + return nullptr; +} +#endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromUint64(DataType dst_dtype) { + CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint64); + return nullptr; +} +#endif // TENSORFLOW_USE_SYCL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index 7da9d28a3d..75e21802c0 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -70,6 +70,8 @@ class CastOpTest : public OpsTestBase { #define TEST_ALL_CASTS_FROM(in) \ TEST_CAST(in, uint8); \ TEST_CAST(in, uint16); \ + TEST_CAST(in, uint32); \ + TEST_CAST(in, uint64); \ TEST_CAST(in, int16); \ TEST_CAST(in, int32); \ TEST_CAST(in, int64); \ @@ -80,6 +82,8 @@ class CastOpTest : public OpsTestBase { TEST_ALL_CASTS_FROM(uint8) TEST_ALL_CASTS_FROM(uint16) +TEST_ALL_CASTS_FROM(uint32) +TEST_ALL_CASTS_FROM(uint64) TEST_ALL_CASTS_FROM(int16) TEST_ALL_CASTS_FROM(int32) TEST_ALL_CASTS_FROM(int64) diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc index a857bd3ce4..a59baaa96f 100644 --- a/tensorflow/core/kernels/cuda_solvers.cc +++ b/tensorflow/core/kernels/cuda_solvers.cc @@ -151,7 +151,7 @@ CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) { reinterpret_cast<const cudaStream_t*>(context->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); cuda_stream_ = *cu_stream_ptr; HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton()); auto it = handle_map->find(cuda_stream_); diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index 5472a192d9..2a25459194 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -641,20 +641,6 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, return Status::OK(); } -namespace { -// Returns whether the context's GPU supports efficient fp16 math. -bool HasFastHalfMath(OpKernelContext* ctx) { - int major, minor; - ctx->op_device_context() - ->stream() - ->parent() - ->GetDeviceDescription() - .cuda_compute_capability(&major, &minor); - auto cuda_arch = major * 100 + minor * 10; - // GPUs before sm_53 don't support fp16 math, and sm_61's fp16 math is slow. - return cuda_arch >= 530 && cuda_arch != 610; -} - namespace detail { template <typename T> struct PseudoHalfType { @@ -666,9 +652,23 @@ struct PseudoHalfType<Eigen::half> { }; } // namespace detail +namespace { // Maps to float if T is __half, and to T otherwise. template <typename T> using PseudoHalfType = typename detail::PseudoHalfType<T>::Type; + +// Returns whether the context's GPU supports efficient fp16 math. +bool HasFastHalfMath(OpKernelContext* ctx) { + int major, minor; + ctx->op_device_context() + ->stream() + ->parent() + ->GetDeviceDescription() + .cuda_compute_capability(&major, &minor); + auto cuda_arch = major * 100 + minor * 10; + // GPUs before sm_53 don't support fp16 math, and sm_61's fp16 math is slow. + return cuda_arch >= 530 && cuda_arch != 610; +} } // namespace template <typename T, DepthwiseConv2dDirection kDirection, diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h index 81df7a51d7..d0d95736d3 100644 --- a/tensorflow/core/util/cuda_launch_config.h +++ b/tensorflow/core/util/cuda_launch_config.h @@ -295,7 +295,7 @@ inline const cudaStream_t& GetCudaStream(OpKernelContext* context) { reinterpret_cast<const cudaStream_t*>(context->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); return *ptr; } diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md index e98206eef9..42ad9652f8 100644 --- a/tensorflow/docs_src/guide/eager.md +++ b/tensorflow/docs_src/guide/eager.md @@ -225,7 +225,7 @@ the tape backwards and then discard. A particular `tf.GradientTape` can only compute one gradient; subsequent calls throw a runtime error. ```py -w = tfe.Variable([[1.0]]) +w = tf.Variable([[1.0]]) with tf.GradientTape() as tape: loss = w * w @@ -260,8 +260,8 @@ def grad(weights, biases): train_steps = 200 learning_rate = 0.01 # Start with arbitrary values for W and B on the same batch of data -W = tfe.Variable(5.) -B = tfe.Variable(10.) +W = tf.Variable(5.) +B = tf.Variable(10.) print("Initial loss: {:.3f}".format(loss(W, B))) @@ -407,11 +407,11 @@ with tf.device("/gpu:0"): ### Variables and optimizers -`tfe.Variable` objects store mutable `tf.Tensor` values accessed during +`tf.Variable` objects store mutable `tf.Tensor` values accessed during training to make automatic differentiation easier. The parameters of a model can be encapsulated in classes as variables. -Better encapsulate model parameters by using `tfe.Variable` with +Better encapsulate model parameters by using `tf.Variable` with `tf.GradientTape`. For example, the automatic differentiation example above can be rewritten: @@ -419,8 +419,8 @@ can be rewritten: class Model(tf.keras.Model): def __init__(self): super(Model, self).__init__() - self.W = tfe.Variable(5., name='weight') - self.B = tfe.Variable(10., name='bias') + self.W = tf.Variable(5., name='weight') + self.B = tf.Variable(10., name='bias') def call(self, inputs): return inputs * self.W + self.B @@ -498,17 +498,17 @@ is removed, and is then deleted. ```py with tf.device("gpu:0"): - v = tfe.Variable(tf.random_normal([1000, 1000])) + v = tf.Variable(tf.random_normal([1000, 1000])) v = None # v no longer takes up GPU memory ``` ### Object-based saving -`tfe.Checkpoint` can save and restore `tfe.Variable`s to and from +`tfe.Checkpoint` can save and restore `tf.Variable`s to and from checkpoints: ```py -x = tfe.Variable(10.) +x = tf.Variable(10.) checkpoint = tfe.Checkpoint(x=x) # save as "x" @@ -612,7 +612,7 @@ def line_search_step(fn, init_x, rate=1.0): `tf.GradientTape` is a powerful interface for computing gradients, but there is another [Autograd](https://github.com/HIPS/autograd)-style API available for automatic differentiation. These functions are useful if writing math code with -only tensors and gradient functions, and without `tfe.Variables`: +only tensors and gradient functions, and without `tf.Variables`: * `tfe.gradients_function` —Returns a function that computes the derivatives of its input function parameter with respect to its arguments. The input diff --git a/tensorflow/docs_src/mobile/index.md b/tensorflow/docs_src/mobile/index.md index 419ae7094a..6032fcad02 100644 --- a/tensorflow/docs_src/mobile/index.md +++ b/tensorflow/docs_src/mobile/index.md @@ -13,9 +13,6 @@ Here are a few of the differences between the two: developed with TensorFlow Lite will have a smaller binary size, fewer dependencies, and better performance. -- TensorFlow Lite is in developer preview, so not all use cases are covered yet. - We expect you to use TensorFlow Mobile to cover production cases. - - TensorFlow Lite supports only a limited set of operators, so not all models will work on it by default. TensorFlow for Mobile has a fuller set of supported functionality. diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/docs_src/mobile/tflite/index.md index 3d1733024e..cc4af2a875 100644 --- a/tensorflow/docs_src/mobile/tflite/index.md +++ b/tensorflow/docs_src/mobile/tflite/index.md @@ -70,10 +70,9 @@ There are several factors which are fueling interest in this domain: We believe the next wave of machine learning applications will have significant processing on mobile and embedded devices. -## TensorFlow Lite developer preview highlights +## TensorFlow Lite highlights -TensorFlow Lite is available as a developer preview and includes the -following: +TensorFlow Lite provides: - A set of core operators, both quantized and float, many of which have been tuned for mobile platforms. These can be used to create and run custom @@ -129,9 +128,6 @@ following: - Java and C++ API support -Note: This is a developer release, and it’s likely that there will be changes in -the API in upcoming versions. We do not guarantee backward or forward -compatibility with this release. ## Getting Started @@ -201,9 +197,5 @@ possible performance for a particular model on a particular device. ## Next Steps -For the developer preview, most of our documentation is on GitHub. Please take a -look at the [TensorFlow Lite -repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite) -on GitHub for more information and for code samples, demo applications, and -more. - +The TensorFlow Lite [GitHub repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite). +contains additional docs, code samples, and demo applications. diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py index 7657b23c60..89e790d4e4 100644 --- a/tensorflow/examples/speech_commands/freeze.py +++ b/tensorflow/examples/speech_commands/freeze.py @@ -130,7 +130,7 @@ def main(_): FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess) if FLAGS.quantize: - tf.contrib.quantize.create_training_graph(quant_delay=0) + tf.contrib.quantize.create_eval_graph() models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint) # Turn all the variables into inline constants inside the graph and save it. diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py index 65ae3b1511..4d1454be0d 100644 --- a/tensorflow/examples/speech_commands/models.py +++ b/tensorflow/examples/speech_commands/models.py @@ -302,7 +302,7 @@ def create_conv_model(fingerprint_input, model_settings, is_training): label_count = model_settings['label_count'] final_fc_weights = tf.get_variable( name='final_fc_weights', - initializer=tf.truncated_normal_initializer, + initializer=tf.truncated_normal_initializer(stddev=0.01), shape=[second_conv_element_count, label_count]) final_fc_bias = tf.get_variable( name='final_fc_bias', diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 573422e533..fbc2a11eda 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -217,10 +217,9 @@ class Model(Network): for name in self.output_names: if name not in loss: logging.warning( - 'Output "' + name + '" missing from loss dictionary. ' - 'We assume this was done on purpose, ' - 'and we will not be expecting ' - 'any data to be passed to "' + name + '" during training.') + 'Output "' + name + '" missing from loss dictionary. We assume ' + 'this was done on purpose. The fit and evaluate APIs will not be ' + 'expecting any data to be passed to "' + name + '".') loss_functions.append(losses.get(loss.get(name))) elif isinstance(loss, list): if len(loss) != len(self.outputs): diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index d9e548f01f..c621a88fb3 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import logging import os import unittest @@ -415,6 +416,28 @@ class TrainingTest(test.TestCase): x2 = model.predict(val_a) self.assertAllClose(x1, x2, atol=1e-7) + def test_compile_warning_for_loss_missing_output(self): + with self.test_session(): + inp = keras.layers.Input(shape=(16,), name='input_a') + out_1 = keras.layers.Dense(8, name='dense_1')(inp) + out_2 = keras.layers.Dense(3, activation='softmax', name='dense_2')(out_1) + model = keras.models.Model(inputs=[inp], outputs=[out_1, out_2]) + + with test.mock.patch.object(logging, 'warning') as mock_log: + model.compile( + loss={ + 'dense_2': 'categorical_crossentropy', + }, + optimizer='rmsprop', + metrics={ + 'dense_2': 'categorical_accuracy', + 'dense_1': 'categorical_accuracy', + }) + msg = ('Output "dense_1" missing from loss dictionary. We assume this ' + 'was done on purpose. The fit and evaluate APIs will not be ' + 'expecting any data to be passed to "dense_1".') + self.assertRegexpMatches(str(mock_log.call_args), msg) + class LossWeightingTest(test.TestCase): diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index e03d7dfe93..72e15763cb 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -19,9 +19,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from abc import ABCMeta +from abc import abstractmethod import six +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.losses import binary_crossentropy from tensorflow.python.keras.losses import categorical_crossentropy from tensorflow.python.keras.losses import cosine_proximity @@ -37,11 +44,385 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy from tensorflow.python.keras.losses import squared_hinge from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import confusion_matrix +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import tf_decorator from tensorflow.python.util.tf_export import tf_export +def update_state(update_state_fn): + """Decorator to wrap metric `update_state()` with `defun()`, `add_update()`. + + Args: + update_state_fn: function that accumulates metric statistics. + + Returns: + If eager execution is enabled, returns None. + If graph execution is enabled, returns an update op. This op should be + executed to update the metric state with the given inputs. + """ + + def decorated(*args, **kwargs): + """Decorated function with `defun()` and `add_update()`.""" + + # Converting update_state_fn() into a graph function, so that + # we can return a single op that performs all of the variable updates. + # Assigning to a different method name to avoid reference cycle. + defuned_update_state_fn = function.defun(update_state_fn) + update_op = defuned_update_state_fn(*args, **kwargs) + if update_op is not None: # update_op will be None in eager execution. + metric_obj = args[0] + metric_obj.add_update(update_op, inputs=True) + return update_op + + return tf_decorator.make_decorator(update_state_fn, decorated) + + +def result(result_fn): + """Decorator to wrap metric `result()` function in `merge_call()`. + + Result computation is an idempotent operation that simply calculates the + metric value using the state variables. + + If metric state variables are distributed across towers/devices and + `result()` is requested from the context of one device - This function wraps + `result()` in a distribution strategy `merge_call()`. With this, + the metric state variables will be aggregated across devices. + + Args: + result_fn: function that computes the metric result. + + Returns: + The metric result tensor. + """ + + def decorated(*args): + """Decorated function with merge_call.""" + tower_context = distribute_lib.get_tower_context() + if tower_context is None: # if in cross tower context already + return result_fn() + + # TODO(psv): Test distribution of metrics using different distribution + # strategies. + + # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn + # with distribution object as the first parameter. We create a wrapper here + # so that the result function need not have that parameter. + def merge_fn_wrapper(distribution, merge_fn, *args): + # We will get `PerDevice` merge function. Taking the first one as all are + # identical copies of the function that we had passed below. + return distribution.unwrap(merge_fn)[0](*args) + + # Wrapping result in merge_call. merge_call is used when we want to leave + # tower mode and compute a value in cross tower mode. + return tower_context.merge_call(merge_fn_wrapper, result_fn, *args) + + return tf_decorator.make_decorator(result_fn, decorated) + + +def _safe_div(numerator, denominator): + """Divides two tensors element-wise, returning 0 if the denominator is <= 0. + + Args: + numerator: A `Tensor`. + denominator: A `Tensor`, with dtype matching `numerator`. + + Returns: + 0 if `denominator` <= 0, else `numerator` / `denominator` + """ + t = math_ops.truediv(numerator, denominator) + zero = array_ops.zeros_like(t, dtype=denominator.dtype) + condition = math_ops.greater(denominator, zero) + zero = math_ops.cast(zero, t.dtype) + return array_ops.where(condition, t, zero) + + +def _squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): + """Squeeze or expand last dimension if needed. + + 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1 + (using `confusion_matrix.remove_squeezable_dimensions`). + 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 + from the new rank of `y_pred`. + If `sample_weight` is scalar, it is kept scalar. + + This will use static shape if available. Otherwise, it will add graph + operations, which could result in a performance hit. + + Args: + y_pred: Predicted values, a `Tensor` of arbitrary dimensions. + y_true: Optional label `Tensor` whose dimensions match `y_pred`. + sample_weight: Optional weight scalar or `Tensor` whose dimensions match + `y_pred`. + + Returns: + Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has + the last dimension squeezed, + `sample_weight` could be extended by one dimension. + """ + if y_true is not None: + # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1 + y_true, y_pred = confusion_matrix.remove_squeezable_dimensions( + y_true, y_pred) + y_pred.get_shape().assert_is_compatible_with(y_true.get_shape()) + + if sample_weight is None: + return y_pred, y_true, None + + sample_weight = ops.convert_to_tensor(sample_weight) + weights_shape = sample_weight.get_shape() + weights_rank = weights_shape.ndims + if weights_rank == 0: # If weights is scalar, do nothing. + return y_pred, y_true, sample_weight + + y_pred_shape = y_pred.get_shape() + y_pred_rank = y_pred_shape.ndims + if (y_pred_rank is not None) and (weights_rank is not None): + # Use static rank. + if weights_rank - y_pred_rank == 1: + sample_weight = array_ops.squeeze(sample_weight, [-1]) + elif y_pred_rank - weights_rank == 1: + sample_weight = array_ops.expand_dims(sample_weight, [-1]) + return y_pred, y_true, sample_weight + + # Use dynamic rank. + weights_rank_tensor = array_ops.rank(sample_weight) + rank_diff = weights_rank_tensor - array_ops.rank(y_pred) + maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1]) + + def _maybe_expand_weights(): + return control_flow_ops.cond( + math_ops.equal(rank_diff, + -1), lambda: array_ops.expand_dims(sample_weight, [-1]), + lambda: sample_weight) + + def _maybe_adjust_weights(): + return control_flow_ops.cond( + math_ops.equal(rank_diff, 1), maybe_squeeze_weights, + _maybe_expand_weights) + + # squeeze or expand last dim of `sample_weight` if its rank differs by 1 + # from the new rank of `y_pred`. + sample_weight = control_flow_ops.cond( + math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight, + _maybe_adjust_weights) + return y_pred, y_true, sample_weight + + +class Metric(Layer): + """Encapsulates metric logic and state. + + Usage with eager execution: + + ```python + m = SomeMetric(...) + for input in ...: + m.update_state(input) + print('Final result: ', m.result().numpy()) + ``` + + Usage with graph execution: + + ```python + m = SomeMetric(...) + init_op = tf.global_variables_initializer() # Initialize variables + with tf.Session() as sess: + sess.run(init_op) + for input in ...: + update_op = m.update_state(input) + sess.run(update_op) + print('Final result: ', sess.run(m.result())) + ``` + + To be implemented by subclasses: + * `__init__()`: All state variables should be created in this method by + calling `self.add_weight()` like: `self.var = self.add_weight(...)` + * `update_state()`: Has all updates to the state variables like: + self.var.assign_add(...). Please decorate the function with: + @update_state: Converts `update_state()` into a graph function, so that + we can return a single op that performs all of the variable updates and + adds the update op to the metric layer. + * `result()`: Computes and returns a value for the metric + from the state variables. Please decorate the function with: + @result: Wraps `result()` in a distribution strategy merge_call(). + + Example subclass implementation: + + ``` + class BinaryTruePositives(Metric): + def __init__(self, name='binary-true-positives', dtype=dtypes.float64): + super(BinaryTruePositives, self).__init__(name=name, dtype=dtype) + self.true_positives = self.add_weight( + 'true_positives', initializer=init_ops.zeros_initializer) + + @update_state + def update_state(self, y_true, y_pred, sample_weight=None): + y_true = math_ops.cast(y_true, dtypes.bool) + y_pred = math_ops.cast(y_pred, dtypes.bool) + y_pred, y_true, sample_weight = _squeeze_or_expand_dimensions( + y_pred, y_true, sample_weight) + + values = math_ops.logical_and( + math_ops.equal(y_true, True), math_ops.equal(y_pred, True)) + values = math_ops.cast(values, self._dtype) + if sample_weight is not None: + sample_weight = math_ops.cast(sample_weight, self._dtype) + values = math_ops.multiply(values, sample_weight) + state_ops.assign_add(self.true_positives, math_ops.reduce_sum(values)) + + @result + def result(self): + return array_ops.identity(self.true_positives) + ``` + """ + __metaclass__ = ABCMeta + + def __init__(self, name=None, dtype=dtypes.float64): + super(Metric, self).__init__(name=name, dtype=dtype) + self.stateful = True # All metric layers are stateful. + self.built = True + + def __call__(self, *args, **kwargs): + """Accumulates statistics and then computes metric result value. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, + passed on to `update_state()`. + + Returns: + The metric value tensor. + """ + update_op = self.update_state(*args, **kwargs) + with ops.control_dependencies([update_op]): + return self.result() + + def reset_states(self): + """Resets all of the metric state variables. + + This function is called between epochs/steps, + when a metric is evaluated during training. + """ + for v in self.variables: + K.set_value(v, 0) + + @abstractmethod + def update_state(self, *args, **kwargs): + """Accumulates statistics for the metric. + + Please decorate the function with: + @update_state: Converts `update_state()` into a graph function, so that + we can return a single op that performs all of the variable updates + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + and adds the update op to the metric layer. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric. + """ + NotImplementedError('Must be implemented in subclasses.') + + @abstractmethod + def result(self): + """Computes and returns the metric value tensor. + + Result computation is an idempotent operation that simply calculates the + metric value using the state variables. + + Please decorate the function with: + @result: Wraps `result()` in a distribution strategy merge_call(). + """ + NotImplementedError('Must be implemented in subclasses.') + + ### For use by subclasses ### + def add_weight(self, + name, + shape=(), + aggregation=vs.VariableAggregation.SUM, + synchronization=vs.VariableSynchronization.ON_READ, + initializer=None): + """Adds state variable. Only for use by subclasses.""" + return super(Metric, self).add_weight( + name=name, + shape=shape, + dtype=self._dtype, + trainable=False, + initializer=initializer, + synchronization=synchronization, + aggregation=aggregation) + + ### End: For use by subclasses ### + + +class Mean(Metric): + """Computes the (weighted) mean of the given values. + + This metric creates two variables, `total` and `count` that are used to + compute the average of `values`. This average is ultimately returned as `mean` + which is an idempotent operation that simply divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + """ + + def __init__(self, name='mean', dtype=dtypes.float64): + super(Mean, self).__init__(name=name, dtype=dtype) + # Create new state variables + self.total = self.add_weight( + 'total', initializer=init_ops.zeros_initializer) + self.count = self.add_weight( + 'count', initializer=init_ops.zeros_initializer) + + @update_state + def update_state(self, values, sample_weight=None): + """Accumulates statistics for computing the mean. + + For example, if `values` is [1, 3, 5, 7] then the mean is 4. If + the `sample_weight` is specified as [1, 1, 0, 0] then the mean would be 2. + + Args: + values: Per-example value. + sample_weight: Optional weighting of each example. Defaults to 1. + """ + values = math_ops.cast(values, self._dtype) + if sample_weight is None: + num_values = math_ops.cast(array_ops.size(values), self._dtype) + else: + sample_weight = math_ops.cast(sample_weight, self._dtype) + + # Update dimensions of weights to match with values. + values, _, sample_weight = _squeeze_or_expand_dimensions( + values, None, sample_weight) + sample_weight = weights_broadcast_ops.broadcast_weights( + sample_weight, values) + num_values = math_ops.reduce_sum(sample_weight) + values = math_ops.multiply(values, sample_weight) + values = math_ops.reduce_sum(values) + + # Update state variables + state_ops.assign_add(self.total, values) + state_ops.assign_add(self.count, num_values) + + @result + def result(self): + return _safe_div(self.total, self.count) + + @tf_export('keras.metrics.binary_accuracy') def binary_accuracy(y_true, y_pred): return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1) diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 15e793f5fc..6d8269f34d 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -18,67 +18,72 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np -from tensorflow.python import keras +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers +from tensorflow.python.keras import metrics +from tensorflow.python.keras.engine.training import Model +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_utils class KerasMetricsTest(test.TestCase): def test_metrics(self): with self.test_session(): - y_a = keras.backend.variable(np.random.random((6, 7))) - y_b = keras.backend.variable(np.random.random((6, 7))) - for metric in [keras.metrics.binary_accuracy, - keras.metrics.categorical_accuracy]: + y_a = K.variable(np.random.random((6, 7))) + y_b = K.variable(np.random.random((6, 7))) + for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]: output = metric(y_a, y_b) - self.assertEqual(keras.backend.eval(output).shape, (6,)) + self.assertEqual(K.eval(output).shape, (6,)) def test_sparse_categorical_accuracy(self): with self.test_session(): - metric = keras.metrics.sparse_categorical_accuracy - y_a = keras.backend.variable(np.random.randint(0, 7, (6,))) - y_b = keras.backend.variable(np.random.random((6, 7))) - self.assertEqual(keras.backend.eval(metric(y_a, y_b)).shape, (6,)) + metric = metrics.sparse_categorical_accuracy + y_a = K.variable(np.random.randint(0, 7, (6,))) + y_b = K.variable(np.random.random((6, 7))) + self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,)) def test_sparse_top_k_categorical_accuracy(self): with self.test_session(): - y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], - [0.1, 0.2, 0.7]])) - y_true = keras.backend.variable(np.array([[1], [0]])) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([[1], [0]])) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) self.assertEqual(result, 1) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) self.assertEqual(result, 0.5) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) def test_top_k_categorical_accuracy(self): with self.test_session(): - y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], - [0.1, 0.2, 0.7]])) - y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]])) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]])) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) self.assertEqual(result, 1) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) self.assertEqual(result, 0.5) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) def test_stateful_metrics(self): with self.test_session(): np.random.seed(1334) - class BinaryTruePositives(keras.layers.Layer): + class BinaryTruePositives(layers.Layer): """Stateful Metric to count the total true positives over all batches. Assumes predictions and targets of shape `(samples, 1)`. @@ -91,11 +96,11 @@ class KerasMetricsTest(test.TestCase): def __init__(self, name='true_positives', **kwargs): super(BinaryTruePositives, self).__init__(name=name, **kwargs) - self.true_positives = keras.backend.variable(value=0, dtype='int32') + self.true_positives = K.variable(value=0, dtype='int32') self.stateful = True def reset_states(self): - keras.backend.set_value(self.true_positives, 0) + K.set_value(self.true_positives, 0) def __call__(self, y_true, y_pred): """Computes the number of true positives in a batch. @@ -120,14 +125,14 @@ class KerasMetricsTest(test.TestCase): return current_true_pos + true_pos metric_fn = BinaryTruePositives() - config = keras.metrics.serialize(metric_fn) - metric_fn = keras.metrics.deserialize( + config = metrics.serialize(metric_fn) + metric_fn = metrics.deserialize( config, custom_objects={'BinaryTruePositives': BinaryTruePositives}) # Test on simple model - inputs = keras.Input(shape=(2,)) - outputs = keras.layers.Dense(1, activation='sigmoid')(inputs) - model = keras.Model(inputs, outputs) + inputs = layers.Input(shape=(2,)) + outputs = layers.Dense(1, activation='sigmoid')(inputs) + model = Model(inputs, outputs) model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['acc', metric_fn]) @@ -184,6 +189,125 @@ class KerasMetricsTest(test.TestCase): self.assertAllClose( val_outs[2], history.history['val_true_positives'][-1], atol=1e-5) + @test_util.run_in_graph_and_eager_modes + def test_mean(self): + m = metrics.Mean(name='my_mean') + + # check config + self.assertEqual(m.name, 'my_mean') + self.assertTrue(m.stateful) + self.assertEqual(m.dtype, dtypes.float64) + self.assertEqual(len(m.variables), 2) + self.evaluate(variables.global_variables_initializer()) + + # check initial state + self.assertEqual(self.evaluate(m.total), 0) + self.assertEqual(self.evaluate(m.count), 0) + + # check __call__() + self.assertEqual(self.evaluate(m(100)), 100) + self.assertEqual(self.evaluate(m.total), 100) + self.assertEqual(self.evaluate(m.count), 1) + + # check update_state() and result() + state accumulation + tensor input + update_op = m.update_state(ops.convert_n_to_tensor([1, 5])) + self.evaluate(update_op) + self.assertEqual(self.evaluate(m.result()), 106 / 3) + self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5 + self.assertEqual(self.evaluate(m.count), 3) + + # check reset_states() + m.reset_states() + self.assertEqual(self.evaluate(m.total), 0) + self.assertEqual(self.evaluate(m.count), 0) + + @test_util.run_in_graph_and_eager_modes + def test_mean_with_sample_weight(self): + m = metrics.Mean() + self.evaluate(variables.global_variables_initializer()) + + # check scalar weight + result_t = m(100, sample_weight=0.5) + self.assertEqual(self.evaluate(result_t), 50 / 0.5) + self.assertEqual(self.evaluate(m.total), 50) + self.assertEqual(self.evaluate(m.count), 0.5) + + # check weights not scalar and weights rank matches values rank + result_t = m([1, 5], sample_weight=[1, 0.2]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 52 / 1.7, 2) + self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2 + self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2 + + # check weights broadcast + result_t = m([1, 2], sample_weight=0.5) + self.assertAlmostEqual(self.evaluate(result_t), 53.5 / 2.7, 2) + self.assertAlmostEqual(self.evaluate(m.total), 53.5, 2) # 52 + 0.5 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 2.7, 2) # 1.7 + 0.5 + 0.5 + + # check weights squeeze + result_t = m([1, 5], sample_weight=[[1], [0.2]]) + self.assertAlmostEqual(self.evaluate(result_t), 55.5 / 3.9, 2) + self.assertAlmostEqual(self.evaluate(m.total), 55.5, 2) # 53.5 + 1 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 3.9, 2) # 2.7 + 1.2 + + # check weights expand + result_t = m([[1], [5]], sample_weight=[1, 0.2]) + self.assertAlmostEqual(self.evaluate(result_t), 57.5 / 5.1, 2) + self.assertAlmostEqual(self.evaluate(m.total), 57.5, 2) # 55.5 + 1 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 5.1, 2) # 3.9 + 1.2 + + def test_mean_graph_with_placeholder(self): + with context.graph_mode(), self.test_session() as sess: + m = metrics.Mean() + v = array_ops.placeholder(dtypes.float32) + w = array_ops.placeholder(dtypes.float32) + sess.run(variables.global_variables_initializer()) + + # check __call__() + result_t = m(v, sample_weight=w) + result = sess.run(result_t, feed_dict=({v: 100, w: 0.5})) + self.assertEqual(sess.run(m.total), 50) + self.assertEqual(sess.run(m.count), 0.5) + self.assertEqual(result, 50 / 0.5) + + # check update_state() and result() + result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]})) + self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2 + self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2 + self.assertAlmostEqual(result, 52 / 1.7, 2) + + @test_util.run_in_graph_and_eager_modes + def test_save_restore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + m = metrics.Mean() + checkpoint = checkpointable_utils.Checkpoint(mean=m) + self.evaluate(variables.global_variables_initializer()) + + # update state + self.evaluate(m(100.)) + self.evaluate(m(200.)) + + # save checkpoint and then add an update + save_path = checkpoint.save(checkpoint_prefix) + self.evaluate(m(1000.)) + + # restore to the same checkpoint mean object + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.evaluate(m(300.)) + self.assertEqual(200., self.evaluate(m.result())) + + # restore to a different checkpoint mean object + restore_mean = metrics.Mean() + restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + status = restore_checkpoint.restore(save_path) + restore_update = restore_mean(300.) + status.assert_consumed().run_restore_ops() + self.evaluate(restore_update) + self.assertEqual(200., self.evaluate(restore_mean.result())) + self.assertEqual(3, self.evaluate(restore_mean.count)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index e358293a90..c739cd2c0d 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -246,6 +246,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[2]]) + def testUseResource(self): + v = variables.Variable(1.0, use_resource=True) + self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable)) + + def testEagerNoUseResource(self): + with context.eager_mode(): + v = variables.Variable(1.0) + self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable)) + @test_util.run_in_graph_and_eager_modes def testScatterMin(self): with ops.device("cpu:0"): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 1f56ad25bf..5979b76ff2 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -1294,3 +1294,16 @@ def is_resource_variable(var): """"Returns True if `var` is to be considered a ResourceVariable.""" return isinstance(var, ResourceVariable) or hasattr( var, "_should_act_as_resource_variable") + + +_DEFAULT_USE_RESOURCE = False + + +def _default_variable_creator(_, *args, **kwds): + use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE) + use_resource = use_resource or context.executing_eagerly() + if use_resource: + return ResourceVariable(*args, **kwds) + return variables.RefVariable(*args, **kwds) + +variables.default_variable_creator = _default_variable_creator diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 77f67c18ee..0f37dcc027 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -191,36 +191,9 @@ class _ReuseMode(enum.Enum): # REUSE_TRUE = 3 -@tf_export("VariableSynchronization") -class VariableSynchronization(enum.Enum): - """Indicates when a distributed variable will be synced.""" - - # Indicates that the synchronization will be determined by the current - # `DistributionStrategy` (eg. With `MirroredStrategy` this would be - # `ON_WRITE`). - AUTO = 0 - - # Indicates that there will only be one copy of the variable, so there is no - # need to sync. - NONE = 1 - - # Indicates that the variable will be aggregated across devices - # every time it is updated. - ON_WRITE = 2 - - # Indicates that the variable will be aggregated across devices - # when it is read (eg. when checkpointing or when evaluating an op that uses - # the variable). - ON_READ = 3 - - -@tf_export("VariableAggregation") -class VariableAggregation(enum.Enum): - """Indicates how a distributed variable will be aggregated.""" - NONE = 0 - SUM = 1 - MEAN = 2 - +# TODO(apassos) remove these forwarding symbols. +VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name +VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name AUTO_REUSE = _ReuseMode.AUTO_REUSE tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE") diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 87e0de197c..6bb2d6f669 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import enum # pylint: disable=g-bad-import-order + import six from tensorflow.core.framework import attr_value_pb2 @@ -38,8 +40,9 @@ from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export -def _default_variable_creator(_, *args, **kwds): - return RefVariable(*args, **kwds) +def default_variable_creator(_, *args, **kwds): + del args, kwds + raise NotImplementedError("resource_variable_ops needs to be imported") def _make_getter(captured_getter, captured_previous): @@ -49,12 +52,43 @@ def _make_getter(captured_getter, captured_previous): return getter +@tf_export("VariableSynchronization") +class VariableSynchronization(enum.Enum): + """Indicates when a distributed variable will be synced.""" + + # Indicates that the synchronization will be determined by the current + # `DistributionStrategy` (eg. With `MirroredStrategy` this would be + # `ON_WRITE`). + AUTO = 0 + + # Indicates that there will only be one copy of the variable, so there is no + # need to sync. + NONE = 1 + + # Indicates that the variable will be aggregated across devices + # every time it is updated. + ON_WRITE = 2 + + # Indicates that the variable will be aggregated across devices + # when it is read (eg. when checkpointing or when evaluating an op that uses + # the variable). + ON_READ = 3 + + +@tf_export("VariableAggregation") +class VariableAggregation(enum.Enum): + """Indicates how a distributed variable will be aggregated.""" + NONE = 0 + SUM = 1 + MEAN = 2 + + class VariableMetaclass(type): """Metaclass to allow construction of tf.Variable to be overridden.""" def __call__(cls, *args, **kwargs): if cls is Variable: - previous_getter = lambda *a, **k: _default_variable_creator(None, *a, **k) + previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k) # TODO(apassos) use a stack of getters here return previous_getter(*args, **kwargs) else: @@ -172,14 +206,6 @@ class Variable(six.with_metaclass(VariableMetaclass, * Replace `tf.Variable` with `tf.contrib.eager.Variable`; * Call `tf.get_variable_scope().set_use_resource(True)` inside a `tf.variable_scope` before the `tf.get_variable()` call. - - @compatibility(eager) - `tf.Variable` is not compatible with eager execution. Use - `tf.contrib.eager.Variable` instead which is compatible with both eager - execution and graph construction. See [the TensorFlow Eager Execution - guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) - for details on how variables work in eager execution. - @end_compatibility """ def __init__(self, @@ -193,7 +219,10 @@ class Variable(six.with_metaclass(VariableMetaclass, dtype=None, expected_shape=None, import_scope=None, - constraint=None): + constraint=None, + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): """Creates a new variable with value `initial_value`. The new variable is added to the graph collections listed in `collections`, @@ -245,20 +274,24 @@ class Variable(six.with_metaclass(VariableMetaclass, variable and return the Tensor for the projected value (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. + use_resource: if True, a ResourceVariable is created; otherwise an + old-style ref-based variable is created. When eager execution is enabled + a resource variable is always created. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + @{tf.VariableAggregation}. Raises: ValueError: If both `variable_def` and initial_value are specified. ValueError: If the initial value is not specified, or does not have a shape and `validate_shape` is `True`. RuntimeError: If eager execution is enabled. - - @compatibility(eager) - `tf.Variable` is not compatible with eager execution. Use - `tfe.Variable` instead which is compatible with both eager execution - and graph construction. See [the TensorFlow Eager Execution - guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) - for details on how variables work in eager execution. - @end_compatibility """ raise NotImplementedError @@ -1714,7 +1747,7 @@ class PartitionedVariable(object): """A container for partitioned `Variable` objects. @compatibility(eager) `tf.PartitionedVariable` is not compatible with - eager execution. Use `tfe.Variable` instead which is compatible + eager execution. Use `tf.Variable` instead which is compatible with both eager execution and graph construction. See [the TensorFlow Eager Execution guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index fd697d70bf..45de047894 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -38,7 +38,14 @@ from tensorflow.python.util.tf_export import tf_export @tf_export('gfile.GFile', 'gfile.Open') class GFile(_FileIO): - """File I/O wrappers without thread locking.""" + """File I/O wrappers without thread locking. + + Note, that this is somewhat like builtin Python file I/O, but + there are semantic differences to make it more efficient for + some backing filesystems. For example, a write mode file will + not be opened until the first write call (to minimize RPC + invocations in network filesystems). + """ def __init__(self, name, mode='r'): super(GFile, self).__init__(name=name, mode=mode) @@ -46,7 +53,14 @@ class GFile(_FileIO): @tf_export('gfile.FastGFile') class FastGFile(_FileIO): - """File I/O wrappers without thread locking.""" + """File I/O wrappers without thread locking. + + Note, that this is somewhat like builtin Python file I/O, but + there are semantic differences to make it more efficient for + some backing filesystems. For example, a write mode file will + not be opened until the first write call (to minimize RPC + invocations in network filesystems). + """ def __init__(self, name, mode='r'): super(FastGFile, self).__init__(name=name, mode=mode) diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index f11022ef1d..259c813c57 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -844,7 +844,7 @@ CUDAExecutor::GetTimerImplementation() { return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this)); } -void *CUDAExecutor::CudaContextHack() { return context_; } +void *CUDAExecutor::GpuContextHack() { return context_; } CudaContext* CUDAExecutor::cuda_context() { return context_; } diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 773cbfb8a1..f7c341c857 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -210,7 +210,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface { std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override; - void *CudaContextHack() override; + void *GpuContextHack() override; CudaContext* cuda_context(); diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h index 02edff6431..bb8bda4755 100644 --- a/tensorflow/stream_executor/cuda/cuda_stream.h +++ b/tensorflow/stream_executor/cuda/cuda_stream.h @@ -40,8 +40,8 @@ class CUDAStream : public internal::StreamInterface { // Note: teardown is handled by a parent's call to DeallocateStream. ~CUDAStream() override {} - void *CudaStreamHack() override { return cuda_stream_; } - void **CudaStreamMemberHack() override { + void *GpuStreamHack() override { return cuda_stream_; } + void **GpuStreamMemberHack() override { return reinterpret_cast<void **>(&cuda_stream_); } diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h index e82f57569f..858396ef96 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.h +++ b/tensorflow/stream_executor/host/host_gpu_executor.h @@ -202,7 +202,7 @@ class HostExecutor : public internal::StreamExecutorInterface { return std::unique_ptr<internal::TimerInterface>(new HostTimer()); } - void *CudaContextHack() override { return nullptr; } + void *GpuContextHack() override { return nullptr; } private: const PluginConfig plugin_config_; diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h index 5d7b8a3782..be88f074cf 100644 --- a/tensorflow/stream_executor/host/host_stream.h +++ b/tensorflow/stream_executor/host/host_stream.h @@ -34,8 +34,8 @@ class HostStream : public internal::StreamInterface { bool EnqueueTask(std::function<void()> task); - void *CudaStreamHack() override { return nullptr; } - void **CudaStreamMemberHack() override { return nullptr; } + void *GpuStreamHack() override { return nullptr; } + void **GpuStreamMemberHack() override { return nullptr; } void BlockUntilDone(); diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 9c989b971d..fb1b92cb84 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -100,19 +100,20 @@ class StreamInterface { // Default destructor for the abstract interface. virtual ~StreamInterface() {} - // Returns the CUDA stream associated with this platform's stream + // Returns the GPU stream associated with this platform's stream // implementation. // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaStreamHack() { return nullptr; } - - // See the above comment on CudaStreamHack -- this further breaks abstraction - // for Eigen within distbelief, which has strong ties to CUDA as a platform, - // and a historical attachment to a programming model which takes a + // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, + // causing a fatal error if it is not. This hack is made available solely for + // use from distbelief code, which temporarily has strong ties to CUDA or + // ROCm as a platform. + virtual void *GpuStreamHack() { return nullptr; } + + // See the above comment on GpuStreamHack -- this further breaks abstraction + // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a + // platform, and a historical attachment to a programming model which takes a // stream-slot rather than a stream-value. - virtual void **CudaStreamMemberHack() { return nullptr; } + virtual void **GpuStreamMemberHack() { return nullptr; } private: SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); @@ -324,13 +325,14 @@ class StreamExecutorInterface { virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0; virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0; - // Returns the CUDA context associated with this StreamExecutor platform - // implementation. + // Returns the CUDA or ROCm context associated with this StreamExecutor + // platform implementation. // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaContextHack() { return nullptr; } + // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, + // causing a fatal error if it is not. This hack is made available solely for + // use from distbelief code, which temporarily has strong ties to CUDA or ROCm + // as a platform. + virtual void *GpuContextHack() { return nullptr; } private: SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt index 23b552cc38..e841c4ad89 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt @@ -49,7 +49,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " } member_method { name: "assign" diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index db37edf809..866fe95d2b 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -354,7 +354,7 @@ do_external_licenses_check(){ # Whitelist echo ${EXTRA_LICENSE_FILE} - grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -v ${EXTRA_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -e "@embedded_jdk//" -v ${EXTRA_LICENSES_FILE} > temp.txt mv temp.txt ${EXTRA_LICENSES_FILE} diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh index adbff8f6ef..e284401b8a 100755 --- a/tensorflow/tools/ci_build/install/install_bazel.sh +++ b/tensorflow/tools/ci_build/install/install_bazel.sh @@ -15,7 +15,7 @@ # ============================================================================== # Select bazel version. -BAZEL_VERSION="0.14.1" +BAZEL_VERSION="0.15.0" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') diff --git a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh index 9d24b3e421..87be81577d 100755 --- a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh +++ b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh @@ -18,7 +18,7 @@ # It will compile bazel from source and install it in /usr/local/bin # Select bazel version. -BAZEL_VERSION="0.14.1" +BAZEL_VERSION="0.15.0" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh index c03cbd9c66..0482cf619a 100644 --- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh +++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh @@ -33,10 +33,10 @@ function set_remote_cache_options { echo "build --tls_enabled=true" >> "${TMP_BAZELRC}" echo "build --remote_timeout=3600" >> "${TMP_BAZELRC}" echo "build --auth_enabled=true" >> "${TMP_BAZELRC}" - echo "build --spawn_strategy=remote" >> "${TMP_BAZELRC}" - echo "build --strategy=Javac=remote" >> "${TMP_BAZELRC}" - echo "build --strategy=Closure=remote" >> "${TMP_BAZELRC}" - echo "build --genrule_strategy=remote" >> "${TMP_BAZELRC}" + echo "build --spawn_strategy=standalone" >> "${TMP_BAZELRC}" + echo "build --strategy=Javac=standalone" >> "${TMP_BAZELRC}" + echo "build --strategy=Closure=standalone" >> "${TMP_BAZELRC}" + echo "build --genrule_strategy=standalone" >> "${TMP_BAZELRC}" echo "build --google_credentials=$GOOGLE_CLOUD_CREDENTIAL" >> "${TMP_BAZELRC}" } diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index fd94d64268..f7fe4119da 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -63,7 +63,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ >>/etc/bazel.bazelrc # Install the most recent bazel release. -ENV BAZEL_VERSION 0.14.1 +ENV BAZEL_VERSION 0.15.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 44120bf274..957a7ed799 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -83,7 +83,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ >>/etc/bazel.bazelrc # Install the most recent bazel release. -ENV BAZEL_VERSION 0.14.1 +ENV BAZEL_VERSION 0.15.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 index 3bedc8cf34..30bc2d2806 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 @@ -4,7 +4,7 @@ LABEL maintainer="Gunhan Gulsoy <gunan@google.com>" # It is possible to override these for releases. ARG TF_BRANCH=master -ARG BAZEL_VERSION=0.5.4 +ARG BAZEL_VERSION=0.15.0 ARG TF_AVAILABLE_CPUS=32 RUN apt-get update && apt-get install -y --no-install-recommends \ diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index af17fd75bc..cb084e49b7 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -247,9 +247,16 @@ Status SortByExecutionOrder(const GraphDef& input_graph_def, } } - if (processed < input_graph_def.node_size()) { - return errors::InvalidArgument(input_graph_def.node_size() - processed, - " nodes in a cycle"); + if (processed < num_nodes) { + LOG(WARNING) << "IN " << __func__ << (num_nodes - processed) + << " NODES IN A CYCLE"; + for (int64 i = 0; i < num_nodes; i++) { + if (pending_count[i] != 0) { + LOG(WARNING) << "PENDING: " << SummarizeNodeDef(input_graph_def.node(i)) + << "WITH PENDING COUNT = " << pending_count[i]; + } + } + return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle"); } return Status::OK(); } diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 378de4261c..4b4f31813c 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -487,11 +487,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bd8c8d759852871609ba2e4e79868420f751949d.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/bd8c8d759852871609ba2e4e79868420f751949d.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/10c3b3d15ed6a788ac12221b784caf81fb8248b5.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/10c3b3d15ed6a788ac12221b784caf81fb8248b5.tar.gz", ], - sha256 = "0c63e8583b213543309e8577ffe87a0cf34cc22269630d2c5c2f0a2345fda4a8", - strip_prefix = "llvm-bd8c8d759852871609ba2e4e79868420f751949d", + sha256 = "a9feb6b47267c30fd7c19ebfdf4dbde6757054f716fa77c09bcb1106799c3253", + strip_prefix = "llvm-10c3b3d15ed6a788ac12221b784caf81fb8248b5", build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), ) diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index 67456a5bdf..c242ef3fdd 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -419,7 +419,7 @@ class SNLIClassifierTrainer(tfe.Checkpointable): # Create a custom learning rate Variable for the RMSProp optimizer, because # the learning rate needs to be manually decayed later (see # decay_learning_rate()). - self._learning_rate = tfe.Variable(lr, name="learning_rate") + self._learning_rate = tf.Variable(lr, name="learning_rate") self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate, epsilon=1e-6) |