author     gracehoney <31743510+aaroey@users.noreply.github.com>  2018-07-18 16:04:12 -0700
committer  gracehoney <31743510+aaroey@users.noreply.github.com>  2018-07-18 16:04:12 -0700
commit     804b14e822f06ade2b52925f42924b1ad5e60790 (patch)
tree       7eee939833f7f9d81365c9df5f852d7b3835e351
parent     0e6bb6e3358a741bd995cb9b0055091c6b42a632 (diff)
parent     5a78e98e877bdca794ffd9e5c4f00da5d2e7ee7d (diff)
Merge branch 'master' of https://github.com/tensorflow/tensorflow into fix_plugin_test
-rw-r--r--  ISSUE_TEMPLATE.md | 3
-rw-r--r--  WORKSPACE | 2
-rw-r--r--  configure.py | 2
-rw-r--r--  tensorflow/compiler/jit/BUILD | 6
-rw-r--r--  tensorflow/compiler/jit/deadness_analysis.cc | 546
-rw-r--r--  tensorflow/compiler/jit/deadness_analysis.h | 68
-rw-r--r--  tensorflow/compiler/jit/deadness_analysis_test.cc | 443
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 40
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc | 6
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc | 26
-rw-r--r--  tensorflow/compiler/jit/xla_device_ops.h | 4
-rw-r--r--  tensorflow/compiler/jit/xla_fusion_optimizer.cc | 12
-rw-r--r--  tensorflow/compiler/tests/sort_ops_test.py | 32
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/if_op.cc | 1
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/topk_op.cc | 40
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/while_op.cc | 1
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc | 10
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc | 52
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.cc | 3
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.h | 3
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.i | 3
-rw-r--r--  tensorflow/compiler/xla/python/xla_client.py | 3
-rw-r--r--  tensorflow/compiler/xla/python/xla_client_test.py | 18
-rw-r--r--  tensorflow/compiler/xla/shape_tree.h | 140
-rw-r--r--  tensorflow/compiler/xla/shape_tree_test.cc | 21
-rw-r--r--  tensorflow/compiler/xla/shape_util.h | 13
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow.py | 4
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow_test.py | 3
-rw-r--r--  tensorflow/contrib/checkpoint/python/containers.py | 6
-rw-r--r--  tensorflow/contrib/cmake/tf_core_kernels.cmake | 1
-rw-r--r--  tensorflow/contrib/data/kernels/BUILD | 11
-rw-r--r--  tensorflow/contrib/data/kernels/assert_next_dataset_op.cc | 152
-rw-r--r--  tensorflow/contrib/data/kernels/csv_dataset_op.cc | 100
-rw-r--r--  tensorflow/contrib/data/ops/dataset_ops.cc | 13
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD | 1
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py | 60
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/serialization/BUILD | 14
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py | 73
-rw-r--r--  tensorflow/contrib/data/python/ops/optimization.py | 53
-rw-r--r--  tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb | 711
-rw-r--r--  tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py | 2
-rw-r--r--  tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py | 5
-rw-r--r--  tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb | 7
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet_test.py | 2
-rw-r--r--  tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py | 2
-rw-r--r--  tensorflow/contrib/eager/python/examples/sagan/sagan.py | 2
-rw-r--r--  tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb | 2
-rw-r--r--  tensorflow/contrib/eager/python/tfe_test.py | 7
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/BUILD | 32
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/buffer_map.cc | 2
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/buffer_map.h | 2
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc | 2
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/delegate_data.cc | 46
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/delegate_data.h | 48
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc | 44
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/util.cc | 2
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/util.h | 2
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/util_test.cc | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/pow_test.cc | 28
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc | 23
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD | 2
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc | 159
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc | 5
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | 1
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 3
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 94
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc | 9
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc | 12
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc | 4
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc | 53
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc | 2
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc | 9
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc | 2
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 53
-rw-r--r--  tensorflow/contrib/lite/toco/model.h | 54
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 5
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc | 17
-rw-r--r--  tensorflow/contrib/makefile/tf_op_files.txt | 2
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager.cc | 2
-rw-r--r--  tensorflow/contrib/tensor_forest/BUILD | 2
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc | 4
-rw-r--r--  tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc | 2
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 4
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.h | 2
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_allocator.cc | 26
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_allocator.h | 14
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_resources.h | 2
-rw-r--r--  tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py | 22
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 6
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt | 2
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.cc | 3
-rw-r--r--  tensorflow/core/common_runtime/process_state.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/process_state.h | 4
-rw-r--r--  tensorflow/core/graph/algorithm.cc | 37
-rw-r--r--  tensorflow/core/graph/algorithm.h | 16
-rw-r--r--  tensorflow/core/graph/algorithm_test.cc | 52
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc | 8
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.cc | 8
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.h | 3
-rw-r--r--  tensorflow/core/grappler/optimizers/data/BUILD | 3
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_utils.cc | 4
-rw-r--r--  tensorflow/core/kernels/BUILD | 2
-rw-r--r--  tensorflow/core/kernels/cast_op.cc | 24
-rw-r--r--  tensorflow/core/kernels/cast_op_gpu.cu.cc | 8
-rw-r--r--  tensorflow/core/kernels/cast_op_impl.h | 30
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_uint32.cc | 46
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_uint64.cc | 46
-rw-r--r--  tensorflow/core/kernels/cast_op_test.cc | 4
-rw-r--r--  tensorflow/core/kernels/cuda_solvers.cc | 2
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | 28
-rw-r--r--  tensorflow/core/util/cuda_launch_config.h | 2
-rw-r--r--  tensorflow/docs_src/guide/eager.md | 22
-rw-r--r--  tensorflow/docs_src/mobile/index.md | 3
-rw-r--r--  tensorflow/docs_src/mobile/tflite/index.md | 16
-rw-r--r--  tensorflow/examples/speech_commands/freeze.py | 2
-rw-r--r--  tensorflow/examples/speech_commands/models.py | 2
-rw-r--r--  tensorflow/python/keras/engine/training.py | 7
-rw-r--r--  tensorflow/python/keras/engine/training_test.py | 23
-rw-r--r--  tensorflow/python/keras/metrics.py | 381
-rw-r--r--  tensorflow/python/keras/metrics_test.py | 196
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py | 9
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 13
-rw-r--r--  tensorflow/python/ops/variable_scope.py | 33
-rw-r--r--  tensorflow/python/ops/variables.py | 75
-rw-r--r--  tensorflow/python/platform/gfile.py | 18
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 2
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.h | 2
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_stream.h | 4
-rw-r--r--  tensorflow/stream_executor/host/host_gpu_executor.h | 2
-rw-r--r--  tensorflow/stream_executor/host/host_stream.h | 4
-rw-r--r--  tensorflow/stream_executor/stream_executor_internal.h | 34
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-variable.pbtxt | 2
-rwxr-xr-x  tensorflow/tools/ci_build/ci_sanity.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_bazel.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_bazel_from_source.sh | 2
-rw-r--r--  tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh | 8
-rw-r--r--  tensorflow/tools/docker/Dockerfile.devel | 2
-rw-r--r--  tensorflow/tools/docker/Dockerfile.devel-gpu | 2
-rw-r--r--  tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 | 2
-rw-r--r--  tensorflow/tools/graph_transforms/transform_utils.cc | 13
-rw-r--r--  tensorflow/workspace.bzl | 8
-rw-r--r--  third_party/examples/eager/spinn/spinn.py | 2
143 files changed, 4189 insertions, 573 deletions
diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md
index 2f3df7cda9..52faed9297 100644
--- a/ISSUE_TEMPLATE.md
+++ b/ISSUE_TEMPLATE.md
@@ -15,9 +15,10 @@ If you open a GitHub issue, here is our policy:
### System information
- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
+- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
-- **Python version**:
+- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
diff --git a/WORKSPACE b/WORKSPACE
index fd7570a80a..17961829a6 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -18,7 +18,7 @@ closure_repositories()
# files, in case the parsing of those build files depends on the bazel
# version we require here.
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
-check_bazel_version_at_least("0.10.0")
+check_bazel_version_at_least("0.15.0")
load("//tensorflow:workspace.bzl", "tf_workspace")
diff --git a/configure.py b/configure.py
index c482628ec8..25729adf36 100644
--- a/configure.py
+++ b/configure.py
@@ -1429,7 +1429,7 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
- check_bazel_version('0.10.0')
+ check_bazel_version('0.15.0')
reset_tf_configure_bazelrc(args.workspace)
cleanup_makefile()
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index c2245b8eae..9174a67cc6 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -304,11 +304,13 @@ cc_library(
name = "compilation_passes",
srcs = [
"build_xla_launch_ops_pass.cc",
+ "deadness_analysis.cc",
"encapsulate_subgraphs_pass.cc",
"mark_for_compilation_pass.cc",
],
hdrs = [
"build_xla_launch_ops_pass.h",
+ "deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"mark_for_compilation_pass.h",
],
@@ -325,6 +327,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -377,6 +380,7 @@ tf_cc_test(
name = "compilation_passes_test",
size = "small",
srcs = [
+ "deadness_analysis_test.cc",
"encapsulate_subgraphs_pass_test.cc",
"mark_for_compilation_pass_test.cc",
],
@@ -387,6 +391,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
+ "//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -458,6 +463,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":common",
+ ":compilation_passes",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
new file mode 100644
index 0000000000..b2d119029a
--- /dev/null
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -0,0 +1,546 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+// ALGORITHM OVERVIEW
+//
+// We map every output produced by each node in the TensorFlow graph (including
+// control dependence) into an instance of the Predicate class. Instances of
+// Predicate denote logical formulas and mapping a node `n` to a predicate
+// `pred` implies that `n` is executed whenever `pred` is true. Then we can
+// deduce mismatching liveness in the inputs to a node by comparing the
+// predicates those inputs are mapped to.
+//
+// Loops are handled pessimistically -- we map Merge nodes with backedges to
+// uninterpreted symbols (the same kind we use to represent Switch and _Recv).
+// Predicate equality has to hold over all possible assignments to these
+// uninterpreted symbols.
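+//
+// As a small informal example: if a Switch is driven by a value V and a
+// boolean predicate P (both produced by always-live nodes), the analysis maps
+// (roughly)
+//
+//   switch:0 (false output) -> ~P
+//   switch:1 (true output)  ->  P
+//
+// so an Add that consumes both outputs has inputs mapped to ~P and P.  These
+// predicates are not equal, hence the Add is reported as having inputs with
+// mismatching deadness (this is the BasicPositive test in
+// deadness_analysis_test.cc).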
+
+namespace tensorflow {
+
+namespace {
+
+// Represents a logical predicate, used as described in the algorithm overview
+// above.
+class Predicate {
+ public:
+ enum class Kind { kAnd, kOr, kNot, kSymbol };
+
+ virtual string ToString() const = 0;
+ virtual bool operator==(const Predicate& other) const = 0;
+ virtual bool operator!=(const Predicate& other) const {
+ return !(*this == other);
+ }
+ int64 hash() const { return hash_; }
+
+ virtual Kind kind() const = 0;
+ virtual ~Predicate() {}
+
+ protected:
+ explicit Predicate(int64 hash) : hash_(hash) {}
+
+ private:
+ const int64 hash_;
+};
+
+int64 HashPredicateSequence(Predicate::Kind kind,
+ gtl::ArraySlice<Predicate*> preds) {
+ int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind);
+ for (Predicate* pred : preds) {
+ hash = Hash64Combine(hash, pred->hash());
+ }
+ return hash;
+}
+
+bool PredicateSequenceEqual(gtl::ArraySlice<Predicate*> lhs,
+ gtl::ArraySlice<Predicate*> rhs) {
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+ for (int64 i = 0; i < lhs.size(); i++) {
+ if (*lhs[i] != *rhs[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+// Represents a logical conjunction of a set of predicates.
+class AndPredicate : public Predicate {
+ public:
+ explicit AndPredicate(std::vector<Predicate*> operands)
+ : Predicate(HashPredicateSequence(Kind::kAnd, operands)),
+ operands_(std::move(operands)) {}
+
+ string ToString() const override {
+ if (operands().empty()) {
+ return "#true";
+ }
+
+ std::vector<string> operands_str;
+ std::transform(operands().begin(), operands().end(),
+ std::back_inserter(operands_str),
+ [](Predicate* pred) { return pred->ToString(); });
+
+ return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
+ }
+
+ bool operator==(const Predicate& other) const override {
+ return other.kind() == Kind::kAnd &&
+ PredicateSequenceEqual(
+ dynamic_cast<const AndPredicate&>(other).operands(), operands());
+ }
+
+ Kind kind() const override { return Kind::kAnd; }
+
+ const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
+ return operands_;
+ }
+
+ private:
+ std::vector<Predicate*> operands_;
+};
+
+// Represents a logical disjunction of a set of predicates.
+class OrPredicate : public Predicate {
+ public:
+ explicit OrPredicate(std::vector<Predicate*> operands)
+ : Predicate(HashPredicateSequence(Kind::kOr, operands)),
+ operands_(std::move(operands)) {}
+
+ string ToString() const override {
+ if (operands().empty()) {
+ return "#false";
+ }
+
+ std::vector<string> operands_str;
+ std::transform(operands().begin(), operands().end(),
+ std::back_inserter(operands_str),
+ [](Predicate* pred) { return pred->ToString(); });
+
+ return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
+ }
+
+ bool operator==(const Predicate& other) const override {
+ return other.kind() == Kind::kOr &&
+ PredicateSequenceEqual(
+ dynamic_cast<const OrPredicate&>(other).operands(), operands());
+ }
+
+ Kind kind() const override { return Kind::kOr; }
+ const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
+ return operands_;
+ }
+
+ private:
+ std::vector<Predicate*> operands_;
+};
+
+// Represents the logical negation of a single predicate.
+class NotPredicate : public Predicate {
+ public:
+ explicit NotPredicate(Predicate* operand)
+ : Predicate(HashPredicateSequence(Kind::kNot, {operand})),
+ operand_(operand) {}
+
+ string ToString() const override {
+ return strings::StrCat("~", operand()->ToString());
+ }
+
+ bool operator==(const Predicate& other) const override {
+ return other.kind() == Kind::kNot &&
+ *dynamic_cast<const NotPredicate&>(other).operand() == *operand();
+ }
+
+ Kind kind() const override { return Kind::kNot; }
+ Predicate* operand() const { return operand_; }
+
+ private:
+ Predicate* operand_;
+};
+
+// Represents an uninterpreted symbol in a logical predicate.
+//
+// Two predicates are equivalent iff they are equivalent for all assignments to
+// the symbols contained in them.
+class SymbolPredicate : public Predicate {
+ public:
+ explicit SymbolPredicate(TensorId tensor_id, bool must_be_true)
+ : Predicate(Hash(tensor_id, must_be_true)),
+ tensor_id_(std::move(tensor_id)),
+ must_be_true_(must_be_true) {}
+
+ string ToString() const override { return tensor_id_.ToString(); }
+ bool operator==(const Predicate& other) const override {
+ return other.kind() == Kind::kSymbol &&
+ must_be_true() ==
+ dynamic_cast<const SymbolPredicate&>(other).must_be_true() &&
+ dynamic_cast<const SymbolPredicate&>(other).tensor_id() ==
+ tensor_id();
+ }
+
+ Kind kind() const override { return Kind::kSymbol; }
+
+ // If `must_be_true()` is true then this SymbolPredicate represents the
+ // proposition "tensor_id() is live and evaluates to true".
+ //
+ // If `must_be_true()` is false then this SymbolPredicate represents the
+ // proposition "tensor_id() is live (and may evaluate to any value)".
+ TensorId tensor_id() const { return tensor_id_; }
+ bool must_be_true() const { return must_be_true_; }
+
+ private:
+ TensorId tensor_id_;
+ bool must_be_true_;
+
+ static int64 Hash(const TensorId tensor_id, bool must_be_true) {
+ return Hash64Combine(
+ ::tensorflow::hash<bool>()(must_be_true),
+ Hash64Combine(::tensorflow::hash<Predicate::Kind>()(Kind::kSymbol),
+ TensorId::Hasher{}(tensor_id)));
+ }
+};
+
+// Creates and owns Predicate instances. Simplifies predicates as it creates
+// them.
+class PredicateFactory {
+ public:
+ Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
+ return MakeAndOrImpl(operands, /*is_and=*/true);
+ }
+ Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
+ return MakeAndOrImpl(operands, /*is_and=*/false);
+ }
+
+ Predicate* MakeNotPredicate(Predicate* pred) {
+ return Make<NotPredicate>(pred);
+ }
+
+ Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) {
+ return Make<SymbolPredicate>(tensor_id, must_be_true);
+ }
+
+ Predicate* MakeTrue() { return MakeAndPredicate({}); }
+ Predicate* MakeFalse() { return MakeOrPredicate({}); }
+
+ private:
+ template <typename PredicateT, typename... Args>
+ Predicate* Make(Args... args) {
+ std::unique_ptr<PredicateT> pred(
+ new PredicateT(std::forward<Args>(args)...));
+ predicate_storage_.emplace_back(std::move(pred));
+ return predicate_storage_.back().get();
+ }
+
+ Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);
+
+ struct PredicatePtrHash {
+ size_t operator()(const Predicate* pred) const { return pred->hash(); }
+ };
+
+ struct PredicatePtrEq {
+ size_t operator()(const Predicate* a, const Predicate* b) const {
+ return *a == *b;
+ }
+ };
+
+ using PredicateSet =
+ gtl::FlatSet<Predicate*, PredicatePtrHash, PredicatePtrEq>;
+
+ std::vector<std::unique_ptr<Predicate>> predicate_storage_;
+};
+
+// Common code to create AndPredicate or OrPredicate instances.
+Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
+ bool is_and) {
+ Predicate::Kind pred_kind =
+ is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
+ PredicateSet simplified_ops_set;
+ std::vector<Predicate*> simplified_ops;
+ for (Predicate* op : operands) {
+ // Simplify A&A => A and A|A => A.
+ if (!simplified_ops_set.insert(op).second) {
+ continue;
+ }
+
+ if (op->kind() == pred_kind) {
+ // "Inline" the operands of an inner And/Or into the parent And/Or.
+ gtl::ArraySlice<Predicate*> operands =
+ is_and ? dynamic_cast<AndPredicate*>(op)->operands()
+ : dynamic_cast<OrPredicate*>(op)->operands();
+ for (Predicate* subop : operands) {
+ if (simplified_ops_set.insert(subop).second) {
+ simplified_ops.push_back(subop);
+ }
+ }
+ } else {
+ simplified_ops.push_back(op);
+ }
+ }
+
+ if (simplified_ops.size() == 1) {
+ return simplified_ops[0];
+ }
+
+ // Simplify "A&~A=>False" and "A|~A=>True".
+ PredicateSet negated_ops;
+ for (Predicate* op : simplified_ops) {
+ if (op->kind() == Predicate::Kind::kNot) {
+ negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
+ }
+ }
+
+ for (Predicate* op : simplified_ops) {
+ if (negated_ops.count(op)) {
+ return is_and ? MakeFalse() : MakeTrue();
+ }
+ }
+
+ std::stable_sort(
+ simplified_ops.begin(), simplified_ops.end(),
+ [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+
+ return is_and ? Make<AndPredicate>(std::move(simplified_ops))
+ : Make<OrPredicate>(std::move(simplified_ops));
+}
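+
+// Some concrete simplifications performed by MakeAndOrImpl above, assuming A,
+// B and C are distinct symbol predicates (operand order in the result is
+// canonicalized by sorting on hash):
+//
+//   MakeAndPredicate({A, A, B})    -> (A & B)      duplicates are dropped
+//   MakeAndPredicate({(A & B), C}) -> (A & B & C)  nested And/Or is inlined
+//   MakeAndPredicate({A, ~A})      -> #false       contradiction
+//   MakeOrPredicate({A, ~A})       -> #true        tautology
+//   MakeAndPredicate({A})          -> A            single operand returned as-is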
+
+class DeadnessAnalysisImpl : public DeadnessAnalysis {
+ public:
+ explicit DeadnessAnalysisImpl(const Graph* graph)
+ : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
+
+ Status Populate();
+ bool HasInputsWithMismatchingDeadness(const Node& node) override;
+ void Print() const override;
+
+ private:
+ enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
+
+ std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind);
+ void SetPred(Node* n, int output_idx, Predicate* pred) {
+ CHECK(
+ predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second);
+ }
+ void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred) {
+ for (int output_idx : output_idxs) {
+ SetPred(n, output_idx, pred);
+ }
+ }
+
+ Status HandleSwitch(Node* n);
+ Status HandleMerge(Node* n);
+ Status HandleRecv(Node* n);
+ Status HandleGeneric(Node* n);
+
+ const Graph& graph_;
+ gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
+ PredicateFactory predicate_factory_;
+ bool vlog_;
+};
+
+TensorId InputEdgeToTensorId(const Edge* e) {
+ return TensorId(e->src()->name(), e->src_output());
+}
+
+std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
+ Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) {
+ std::vector<Predicate*> incoming_preds;
+ for (const Edge* in_edge : n->in_edges()) {
+ bool should_process =
+ edge_kind == EdgeKind::kDataAndControl ||
+ (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
+ (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);
+
+ if (should_process) {
+ auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
+ CHECK(it != predicate_map_.end());
+ incoming_preds.push_back(it->second);
+ }
+ }
+ return incoming_preds;
+}
+
+Status DeadnessAnalysisImpl::HandleSwitch(Node* n) {
+ std::vector<Predicate*> input_preds =
+ GetIncomingPreds(n, EdgeKind::kDataAndControl);
+ const Edge* pred_edge;
+ TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
+ Predicate* true_switch = predicate_factory_.MakeSymbolPredicate(
+ TensorId(pred_edge->src()->name(), pred_edge->src_output()),
+ /*must_be_true=*/true);
+ Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
+
+ // Output 0 is alive iff all inputs are alive and the condition is false.
+ input_preds.push_back(false_switch);
+ SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds));
+ input_preds.pop_back();
+
+ // Output 1 is alive iff all inputs are alive and the condition is true.
+ input_preds.push_back(true_switch);
+ SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds));
+ input_preds.pop_back();
+
+ // Control is alive iff all inputs are alive.
+ SetPred(n, Graph::kControlSlot,
+ predicate_factory_.MakeAndPredicate(input_preds));
+
+ return Status::OK();
+}
+
+Status DeadnessAnalysisImpl::HandleMerge(Node* n) {
+ // Merge ignores deadness of its control inputs. A merge that isn't the
+ // target of a backedge is alive iff any of its data inputs are. We treat
+ // the liveness of a merge that is the target of a backedge symbolically.
+
+ bool has_backedge = std::any_of(
+ n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) {
+ return !e->IsControlEdge() && e->src()->IsNextIteration();
+ });
+
+ Predicate* input_data_pred =
+ has_backedge ? predicate_factory_.MakeSymbolPredicate(
+ TensorId(n->name(), 0), /*must_be_true=*/false)
+ : predicate_factory_.MakeOrPredicate(
+ GetIncomingPreds(n, EdgeKind::kDataOnly));
+
+ SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred);
+ return Status::OK();
+}
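+
+// For example, a Merge fed by the false outputs of two Switches gets the
+// predicate (~P0 | ~P1) on outputs 0, 1 and the control slot, whereas a Merge
+// that is the target of a NextIteration backedge (a loop header) is simply
+// mapped to a fresh uninterpreted symbol named after the Merge itself.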
+
+Status DeadnessAnalysisImpl::HandleRecv(Node* n) {
+ // In addition to being alive or dead based on the inputs, a _Recv can also
+ // acquire a dead signal from a _Send.
+ std::vector<Predicate*> input_preds =
+ GetIncomingPreds(n, EdgeKind::kDataAndControl);
+ input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
+ TensorId(n->name(), 0), /*must_be_true=*/false));
+ SetPred(n, {0, Graph::kControlSlot},
+ predicate_factory_.MakeAndPredicate(input_preds));
+ return Status::OK();
+}
+
+Status DeadnessAnalysisImpl::HandleGeneric(Node* n) {
+ // Generally nodes are alive iff all their inputs are alive.
+ Predicate* pred = predicate_factory_.MakeAndPredicate(
+ GetIncomingPreds(n, EdgeKind::kDataAndControl));
+ for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
+ SetPred(n, output_idx, pred);
+ }
+ SetPred(n, Graph::kControlSlot, pred);
+ return Status::OK();
+}
+
+Status DeadnessAnalysisImpl::Populate() {
+ std::vector<Node*> rpo;
+ GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{},
+ /*edge_filter=*/[](const Edge& edge) {
+ return !edge.src()->IsNextIteration();
+ });
+
+ // This is an abstract interpretation over the deadness propagation semantics of
+ // the graph executor.
+ for (Node* n : rpo) {
+ if (n->IsSwitch()) {
+ TF_RETURN_IF_ERROR(HandleSwitch(n));
+ } else if (n->IsMerge()) {
+ TF_RETURN_IF_ERROR(HandleMerge(n));
+ } else if (n->IsControlTrigger()) {
+ SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue());
+ } else if (n->IsRecv() || n->IsHostRecv()) {
+ TF_RETURN_IF_ERROR(HandleRecv(n));
+ } else {
+ TF_RETURN_IF_ERROR(HandleGeneric(n));
+ }
+ }
+
+ return Status::OK();
+}
+
+bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
+ CHECK(!node.IsMerge());
+
+ if (vlog_) {
+ VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
+ }
+
+ Predicate* pred = nullptr;
+ for (const Edge* edge : node.in_edges()) {
+ auto it = predicate_map_.find(InputEdgeToTensorId(edge));
+ CHECK(it != predicate_map_.end());
+ if (vlog_) {
+ VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": "
+ << it->second->ToString();
+ }
+
+ // Today we just compare the predicates for equality (with some
+ // canonicalization/simplification happening before) but we could be more
+ // sophisticated here if need be.
+ if (pred != nullptr && *pred != *it->second) {
+ if (vlog_) {
+ VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
+ << ") -> true";
+ }
+ return true;
+ }
+ pred = it->second;
+ }
+
+ if (vlog_) {
+ VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
+ << ") -> false";
+ }
+
+ return false;
+}
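+
+// Note that the equality check above leans on the canonicalization done in
+// PredicateFactory::MakeAndOrImpl: inputs mapped to, say, (~P0 & ~P1) and
+// (~P1 & ~P0) compare equal because operands are deduplicated and sorted by
+// hash, so such inputs are correctly not reported as mismatching (see the
+// AndIsCommutative test).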
+
+void DeadnessAnalysisImpl::Print() const {
+ std::vector<TensorId> tensor_ids;
+ for (const auto& kv_pair : predicate_map_) {
+ tensor_ids.push_back(kv_pair.first);
+ }
+
+ std::sort(tensor_ids.begin(), tensor_ids.end());
+
+ for (TensorId tensor_id : tensor_ids) {
+ auto it = predicate_map_.find(tensor_id);
+ CHECK(it != predicate_map_.end()) << tensor_id.ToString();
+ VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
+ }
+}
+
+} // namespace
+
+DeadnessAnalysis::~DeadnessAnalysis() {}
+
+/*static*/ Status DeadnessAnalysis::Run(
+ const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
+ std::unique_ptr<DeadnessAnalysisImpl> analysis(
+ new DeadnessAnalysisImpl(&graph));
+ TF_RETURN_IF_ERROR(analysis->Populate());
+
+ if (VLOG_IS_ON(2)) {
+ analysis->Print();
+ }
+
+ *result = std::move(analysis);
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h
new file mode 100644
index 0000000000..6e7ab41161
--- /dev/null
+++ b/tensorflow/compiler/jit/deadness_analysis.h
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// This analyzes a TensorFlow graph to identify nodes which may have partially
+// dead inputs (i.e. these nodes may have some dead inputs and some alive
+// inputs).
+//
+// For example, the ADD node in the following graph
+//
+//          V0   PRED0        V1   PRED1
+//           |       |         |       |
+//           v       v         v       v
+//             SWITCH            SWITCH
+//                |                  |
+//                +-------+  +-------+
+//                        |  |
+//                        v  v
+//                         ADD
+//
+// can have its inputs independently dead or alive based on the runtime values
+// of PRED0 and PRED1.
+//
+// It is tempting to call this a liveness analysis but I avoided that because
+// "liveness" already has other connotations.
+class DeadnessAnalysis {
+ public:
+ // Returns true if `node` may have some live inputs and some dead inputs.
+ //
+ // This is a conservatively correct routine -- if it returns false then `node`
+ // is guaranteed to not have inputs with mismatching liveness, but not the
+ // converse.
+ //
+ // REQUIRES: node is not a Merge operation.
+ virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0;
+
+ // Prints out the internal state of this instance. For debugging purposes
+ // only.
+ virtual void Print() const = 0;
+ virtual ~DeadnessAnalysis();
+
+ // Runs the deadness analysis over `graph` and returns an error or a populated
+ // instance of DeadnessAnalysis in `result`.
+ static Status Run(const Graph& graph,
+ std::unique_ptr<DeadnessAnalysis>* result);
+};
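+
+// Typical usage (mark_for_compilation_pass.cc contains a real call site):
+//
+//   std::unique_ptr<DeadnessAnalysis> deadness;
+//   TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph, &deadness));
+//   if (!node->IsMerge() &&
+//       deadness->HasInputsWithMismatchingDeadness(*node)) {
+//     // Treat `node` as having potentially conflicting input deadness.
+//   }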
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
new file mode 100644
index 0000000000..584385cab7
--- /dev/null
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -0,0 +1,443 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/deadness_analysis.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Status AnalyzeDeadness(Graph* graph,
+ std::unique_ptr<DeadnessAnalysis>* result) {
+ FixupSourceAndSinkEdges(graph);
+ return DeadnessAnalysis::Run(*graph, result);
+}
+
+ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
+ Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
+ Output predicate =
+ ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
+ return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
+}
+
+Output CreateInductionVariable(const Scope& root, const string& prefix,
+ const string& frame_name, int32 init) {
+ Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init);
+ Output enter_initial_value = ops::internal::Enter(
+ root.WithOpName(prefix + "/enter"), initial_value, frame_name);
+
+ ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value});
+ Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
+ Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
+ Output loop_cond_expr =
+ ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value);
+ Output loop_cond =
+ ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
+ ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
+ ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
+ Output iv_next =
+ ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by);
+ Output next_iteration =
+ ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next);
+
+ root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1);
+ root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
+ root.graph()->AddControlEdge(iv.output.node(), final_value.node());
+
+ return iv.output;
+}
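+
+// CreateInductionVariable builds (roughly) the standard while-loop skeleton:
+// Const -> Enter -> Merge -> Less -> LoopCond -> Switch -> Exit, with an
+// Add/NextIteration path wired back into the Merge as a backedge.  The Merge
+// it creates is therefore exactly the kind that HandleMerge treats
+// symbolically.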
+
+TEST(DeadnessAnalysisTest, BasicPositive) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw = CreateSwitch(root, "0");
+ Output add =
+ ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, BasicNegative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
+ Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
+ Output add = ops::Add(root.WithOpName("add"), a, b);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, AndIsCommutative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+
+ Output a0 =
+ ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
+ Output a1 =
+ ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);
+
+ Output b0 =
+ ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
+ Output b1 =
+ ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);
+
+ Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
+ Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);
+
+ Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
+ Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
+}
+
+TEST(DeadnessAnalysisTest, AndIsAssociative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+
+ Output a0 =
+ ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
+ Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false);
+
+ Output b0 =
+ ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
+ Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0);
+
+ Output add = ops::Add(root.WithOpName("add"), a1, b1);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, OrIsCommutative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+
+ ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
+ ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
+ ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
+ ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});
+
+ Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
+ Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);
+
+ Output halfdead0 =
+ ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
+ Output halfdead1 =
+ ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
+}
+
+TEST(DeadnessAnalysisTest, OrIsAssociative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+
+ ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
+ ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false});
+ ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false});
+ ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output});
+
+ Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, AndOfOr) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+ ops::Switch sw_3 = CreateSwitch(root, "3");
+
+ ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
+ ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false});
+
+ Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output);
+ Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output);
+
+ Output add2 = ops::Add(root.WithOpName("add2"), add0, add1);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+}
+
+TEST(DeadnessAnalysisTest, OrOfAnd) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+ ops::Switch sw_3 = CreateSwitch(root, "3");
+
+ Output add0 =
+ ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
+ Output add1 =
+ ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);
+
+ ops::Merge m0(root.WithOpName("m0"), {add0, add1});
+ ops::Merge m1(root.WithOpName("m1"), {add0, add1});
+
+ Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+}
+
+TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
+ // This demonstrates one of the weaknesses in the current approach -- since we
+ // only do some basic simplifications we can't see that "(A|B)&C" ==
+ // "(A&C)|(B&C)".
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+
+ ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
+ Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false);
+
+ Output add1 =
+ ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
+ Output add2 =
+ ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
+ ops::Merge m1(root.WithOpName("m1"), {add1, add2});
+
+ Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+}
+
+TEST(DeadnessAnalysisTest, Ternary) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
+ Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
+ Output false_value =
+ ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);
+
+ ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
+ predicate);
+
+ ops::Switch predicated_false(root.WithOpName("predicated_false"), true_value,
+ predicate);
+ ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
+ predicated_false.output_false});
+ Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
+ Output add = ops::Add(root.WithOpName("add"), merge.output, addend);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, Recv) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
+ "sender", 0, "receiver");
+ Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
+ "sender", 0, "receiver");
+ Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, HostRecv) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
+ "tensor_a", "sender", 0, "receiver");
+ Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
+ "tensor_b", "sender", 0, "receiver");
+ Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, Loop) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0);
+ Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0);
+ Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1);
+ Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
+ Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have
+ // noticed that. Today we are pessimistic here because we assign an
+ // uninterpreted symbol to merges with backedges.
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
+}
+
+TEST(DeadnessAnalysisTest, ControlInputs) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::Switch sw = CreateSwitch(root, "0");
+
+ Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
+ Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
+
+ Output const0 = ops::Const(root.WithOpName("const0"), 1);
+ Output const1 = ops::Const(root.WithOpName("const1"), 2);
+
+ Output add = ops::Add(root.WithOpName("add"), const0, const1);
+
+ root.graph()->AddControlEdge(id0.node(), const0.node());
+ root.graph()->AddControlEdge(id1.node(), const1.node());
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, ControlTrigger) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::Switch sw = CreateSwitch(root, "0");
+
+ Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
+ Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
+
+ ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0"));
+ ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1"));
+
+ Output const0 = ops::Const(root.WithOpName("const0"), 1);
+ Output const1 = ops::Const(root.WithOpName("const1"), 2);
+
+ Output add = ops::Add(root.WithOpName("add"), const0, const1);
+
+ root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node());
+ root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node());
+
+ root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node());
+ root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node());
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::Switch sw = CreateSwitch(root, "0");
+
+ Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
+ Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
+
+ Output constant = ops::Const(root.WithOpName("constant"), 5);
+ ops::Merge m0(root.WithOpName("m0"), {constant});
+ ops::Merge m1(root.WithOpName("m1"), {constant});
+ Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);
+
+ root.graph()->AddControlEdge(id0.node(), m0.output.node());
+ root.graph()->AddControlEdge(id1.node(), m1.output.node());
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, RecvVsSwitch) {
+ // Demonstrates why we need the must_be_true bit on SymbolPredicate.
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
+ 0, "receiver");
+ Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
+ ops::Switch sw(root.WithOpName("switch"), value, recv);
+ Output logical_and =
+ ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index 9c424b201e..fdd71c6a58 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -138,7 +138,7 @@ class Encapsulator {
// Find subgraphs marked with 'group_attribute', and build a new
// subgraph, one for each value of 'group_attribute'.
- Status SplitIntoSubgraphs();
+ Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);
// Build a FunctionDef for each subgraph, and add it 'library'. The values of
// the 'group_attribute' annotations become the function names.
@@ -1478,7 +1478,7 @@ Status Encapsulator::CopySubgraphEdges(
return Status::OK();
}
-Status Encapsulator::SplitIntoSubgraphs() {
+Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
Status s;
// Map from input graph nodes to subgraph nodes.
@@ -1513,6 +1513,15 @@ Status Encapsulator::SplitIntoSubgraphs() {
TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy));
}
+ if (VLOG_IS_ON(1)) {
+ // Dump subgraphs.
+ for (auto& entry : subgraphs_) {
+ dump_graph::DumpGraphToFile(
+ strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
+ *entry.second.GetGraph(), library);
+ }
+ }
+
return s;
}
@@ -1936,6 +1945,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
// continue.
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
+ VLOG(2) << "Node " << src_node->name()
+ << " has known shape: " << proto.DebugString();
if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
dummy_node_images[src_node] =
AddDummyShapedNode(src_node, src_port, control_flow_info,
@@ -1953,6 +1964,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
if (VLOG_IS_ON(2)) {
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
+ VLOG(2) << "Node " << src_node->name()
+ << " has unknown shape: " << proto.DebugString();
}
stack.push_back({src_node, false});
}
@@ -2195,6 +2208,23 @@ Status Encapsulator::FindClusterDependencies() {
}
}
}
+ if (VLOG_IS_ON(2)) {
+ // Print debug information.
+ VLOG(2) << "node_ancestors_map:";
+ for (const auto& node_iter : node_ancestors_map) {
+ VLOG(2) << "\t" << node_iter.first->name() << ": subgraph = '"
+ << node_iter.second.subgraph
+ << "', outside_compilation_cluster = '"
+ << node_iter.second.outside_compilation_cluster
+ << "', ancestor_clusters: "
+ << (node_iter.second.ancestor_clusters.empty() ? "(empty)" : "");
+ for (const auto& cluster_iter : node_iter.second.ancestor_clusters) {
+ VLOG(2) << "\t\tsubgraph = '" << cluster_iter.subgraph
+ << "', outside_compilation_cluster = '"
+ << cluster_iter.outside_compilation_cluster << "'";
+ }
+ }
+ }
return Status::OK();
}
@@ -2402,7 +2432,7 @@ Status EncapsulateSubgraphsInFunctions(
std::move(outside_compilation_attribute),
&graph_in);
TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies());
- TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
+ TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));
TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
rewrite_subgraph_fn, reuse_existing_functions, library));
@@ -2451,7 +2481,7 @@ Status EncapsulateSubgraphsPass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "EncapsulateSubgraphsPass::Run";
if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
+ dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
options.flib_def);
}
@@ -2534,7 +2564,7 @@ Status EncapsulateSubgraphsPass::Run(
"EncapsulateSubgraphsPass failed");
if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
+ dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
options.flib_def);
}
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 338fb5a6f0..c5d0e4f8fb 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -51,7 +51,11 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
- platform_id_ = se::cuda::kCudaPlatformId;
+ platform_id_ = ctx->device()
+ ->tensorflow_gpu_device_info()
+ ->stream->parent()
+ ->platform()
+ ->id();
} else {
platform_id_ = nullptr;
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 8c3882116d..6558f14dd6 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
@@ -462,17 +464,27 @@ Status MarkForCompilationPass::Run(
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
const FunctionLibraryDefinition* fld = options.flib_def;
- auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld](
- const Node* node, const DeviceType& device_type) {
+ std::unique_ptr<DeadnessAnalysis> deadness;
+ {
+ XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 0);
+ TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness));
+ }
+
+ auto is_compilable = [&](const Node* node, const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
return false;
}
+ // TODO(b/111570009): This bailout for ControlTrigger is probably not
+ // needed.
+ //
// Don't compile control trigger nodes. We won't preserve their deadness
// semantics correctly, so it's safest not to compile them.
- if (node->IsControlTrigger()) return false;
+ if (node->IsControlTrigger()) {
+ return false;
+ }
// If this device requires a JIT, we must say yes.
if (registration->requires_compilation) return true;
@@ -485,6 +497,14 @@ Status MarkForCompilationPass::Run(
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
if (status.ok()) return compile;
+ // If inputs to `node` can have conflicting deadness (i.e. some are alive
+ // and some are dead) then don't compile it. XLA cannot represent the
+ // deadness semantics of these nodes correctly and auto-clustering these
+    // nodes can cause deadness to propagate to nodes that should be live.
+ if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
+ return false;
+ }
+
// Check for fusable ops only if requested.
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
return false;
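
Note on the deadness bailout above: a node's inputs have mismatching deadness when, under TensorFlow's Switch/Merge dataflow, some inputs can be live while others are dead, so clustering such a node into an XLA computation could make dead values look live. A rough plain-Python model of the predicate over a toy graph (hypothetical names, not the real DeadnessAnalysis C++ API):

# Toy sketch: each input edge carries a symbolic liveness predicate; a node has
# mismatching input deadness if those predicates are not all identical.
def has_inputs_with_mismatching_deadness(node, edge_liveness):
    predicates = {edge_liveness[e] for e in node["in_edges"]}
    return len(predicates) > 1

# Example: a Merge fed by both outputs of a Switch. One input is live when
# `pred` is true, the other when it is false, so the node is not auto-clustered.
merge = {"in_edges": ["switch_true", "switch_false"]}
liveness = {"switch_true": "pred", "switch_false": "~pred"}
assert has_inputs_with_mismatching_deadness(merge, liveness)
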
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 134dcc1bb5..6adda327f1 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -77,9 +77,7 @@ class XlaAssignVariableOp : public AsyncOpKernel {
ConstantOp); \
REGISTER_KERNEL_BUILDER( \
Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
- REGISTER_KERNEL_BUILDER( \
- Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \
- IdentityNOp); \
+ REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \
PlaceholderOp); \
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
index 74257b09a8..b70e1cf52b 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
@@ -146,6 +147,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
+ std::unique_ptr<DeadnessAnalysis> deadness;
+ TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
+
// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
@@ -185,6 +189,14 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
continue;
}
+ // If inputs to `node` can have conflicting deadness (i.e. some are alive
+ // and some are dead) then don't compile it. XLA cannot represent the
+ // deadness semantics of these nodes correctly and auto-clustering these
+    // nodes can cause deadness to propagate to nodes that should be live.
+ if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
+ continue;
+ }
+
compilation_candidates.insert(node);
}
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 9e2ef964a1..7ff01be3cb 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -88,6 +88,38 @@ class XlaSortOpTest(xla_test.XLATestCase):
topk, [x.astype(dtype)],
expected=[x[indices].astype(dtype), indices])
+ def testTopK2D(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for dtype in supported_types.intersection(self.numeric_types):
+ # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+ # after conversion to bfloat16, so the possible resulting index array is
+ # no longer unique.
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ array_size = 10
+ k_options = [0, 1, 2, 10]
+ else:
+ array_size = 200 * 1000
+ k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+ batch = 16
+ for x in [np.arange(batch * array_size)]:
+ np.random.shuffle(x)
+ x = np.reshape(x, [batch, array_size])
+ for k in k_options:
+ indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
+ expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
+
+ def topk(v, k=k):
+ return nn_ops.top_k(v, k=k, sorted=True)
+
+ self._assertOpOutputMatchesExpected(
+ topk, [x.astype(dtype)],
+ expected=[expected.astype(dtype), indices])
+
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
# TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
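
Note on testTopK2D above: the expected outputs are built with a reverse slice over argsort/sort along the last axis, and bfloat16 uses a small array because rounding to bfloat16 can collapse nearby integers into equal values, which would make the expected index array ambiguous. A small standalone NumPy illustration of the same indexing (independent of the test harness):

import numpy as np

x = np.array([[5, 1, 9, 3],
              [2, 8, 4, 7]])
k = 2
# Per-row indices of the k largest values, in descending order of value.
indices = x.argsort(axis=1)[::, -1:-k - 1:-1]    # [[2, 0], [1, 3]]
# Per-row k largest values, in descending order.
expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]  # [[9, 5], [8, 7]]
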
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index f5fcf3cacd..e2160feba0 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -246,6 +246,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Done building If";
}
+REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp);
REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 1ddcb08c8e..82d4a69777 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel {
OP_REQUIRES(context, input_shape.dims() >= 1,
errors::InvalidArgument("input must be >= 1-D, got shape ",
input_shape.DebugString()));
+ int last_dim = input_shape.dims() - 1;
+ int last_dim_size = input_shape.dim_size(last_dim);
OP_REQUIRES(
- context, input_shape.dim_size(input_shape.dims() - 1) >= k,
+ context, last_dim_size >= k,
errors::InvalidArgument("input must have at least k columns. Had ",
- input_shape.dim_size(input_shape.dims() - 1),
- ", needed ", k));
-
- OP_REQUIRES(
- context, input_shape.dims() == 1,
- errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
- input_shape.DebugString()));
+ last_dim_size, ", needed ", k));
xla::XlaBuilder* const b = context->builder();
- if (input_shape.dim_size(0) < k) {
- k = input_shape.dim_size(0);
+ if (last_dim_size < k) {
+ k = last_dim_size;
}
const xla::XlaOp input = context->Input(0);
- xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
- xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
+
+ xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size);
+ auto input_dims = input_shape.dim_sizes();
+ std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
+ xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims);
+ xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32);
+
+ std::vector<int64> start_indices(input_shape.dims(), 0);
+ std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+ limit_indices[last_dim] = k;
+ std::vector<int64> strides(input_shape.dims(), 1);
+
xla::XlaOp values =
- xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1}));
+ xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices,
+ limit_indices, strides));
xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1});
+ start_indices, limit_indices, strides);
context->SetOutput(0, values);
context->SetOutput(1, indices);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 9413a30a6c..009fdd81b2 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -299,6 +299,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Done building while loop";
}
+REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 319cbc74e9..cb47581e36 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -422,16 +422,18 @@ Status BuildComputation(
// assignment will be placed on this value, which will cause the resource
// update to be returned from the same device that provided the resource.
handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
-
elems.push_back(handle);
}
}
*num_computation_outputs = elems.size();
- // Builds the XLA computation.
- if (always_return_tuple || elems.size() != 1) {
- xla::Tuple(builder, elems);
+ // Builds the XLA computation. We *always* form a tuple here to ensure that
+ // the output value is the last thing added into the XLA computation, even
+ // if there is only one output value.
+ auto tuple = xla::Tuple(builder, elems);
+ if (!always_return_tuple && elems.size() == 1) {
+ xla::GetTupleElement(tuple, 0);
}
builder->ClearOpMetadata();
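
Note on the change above: an XLA computation's root is the last operation added to the builder, so the single-output, non-tuple case now appends a Tuple followed by a GetTupleElement to make sure the intended output is what ends up as the root, even when the _Retval node is not last in graph construction order. A plain-Python toy model of that rule (a hypothetical toy builder, not the real xla::XlaBuilder API):

class ToyBuilder(object):
  """Toy model of a builder whose root output is the last op added."""

  def __init__(self):
    self.ops = []

  def add(self, op):
    self.ops.append(op)
    return op

  def root(self):
    return self.ops[-1]


b = ToyBuilder()
retval = b.add("retval_A")   # the value we actually want to return
b.add("add_C")               # an unrelated op added later in construction order
# Without the fix, root() would be "add_C". Appending the tuple (and, for a
# single untupled output, a GetTupleElement) last roots the intended value.
tup = b.add(("tuple", (retval,)))
b.add(("get_tuple_element", tup, 0))
assert b.root() == ("get_tuple_element", tup, 0)
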
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 6f76816a86..2fb93be01d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -228,6 +228,58 @@ TEST_F(XlaCompilerTest, Simple) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
+// Tests compilation of a graph where the _Retval node is not necessarily last
+// amongst the graph nodes in construction order, and always_return_tuple is
+// false. Regression test for a bug where the wrong value was returned.
+TEST_F(XlaCompilerTest, OutOfOrderGraph) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
+ // The _Retval node is not last in construction order.
+ auto d = ops::_Retval(scope.WithOpName("D"), a, 0);
+ auto c = ops::Add(scope.WithOpName("C"), a, b);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(2);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({2});
+ args[1].kind = XlaCompiler::Argument::kParameter;
+ args[1].type = DT_INT32;
+ args[1].shape = TensorShape({2});
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.always_return_tuple = false;
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
+ args, &result));
+
+ // Tests that the generated computation works.
+ std::unique_ptr<xla::Literal> param0_literal =
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
+ std::unique_ptr<xla::Literal> param1_literal =
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<xla::GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::GlobalData> actual =
+ client_
+ ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+ .ConsumeValueOrDie();
+ std::unique_ptr<xla::Literal> actual_literal =
+ client_->Transfer(*actual).ConsumeValueOrDie();
+
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+}
+
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
  // Builds a graph that reshapes a tensor, but with the shape not
// statically known.
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index be55d50b23..66b1c08a39 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -614,6 +614,9 @@ _FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_BINOP(Xor)
+_FORWARD_BINOP(ShiftLeft)
+_FORWARD_BINOP(ShiftRightArithmetic)
+_FORWARD_BINOP(ShiftRightLogical)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 690ff277e8..17ad044578 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -333,6 +333,9 @@ class LocalComputationBuilder {
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_BINOP(Xor)
+ _FORWARD_BINOP(ShiftLeft)
+ _FORWARD_BINOP(ShiftRightArithmetic)
+ _FORWARD_BINOP(ShiftRightLogical)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index afdea88cb7..42bf76e5d8 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -989,6 +989,9 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::And;
%unignore xla::swig::LocalComputationBuilder::Or;
%unignore xla::swig::LocalComputationBuilder::Xor;
+%unignore xla::swig::LocalComputationBuilder::ShiftLeft;
+%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic;
+%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical;
%unignore xla::swig::LocalComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index e2b6eaa096..f93d7bda2d 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -125,6 +125,9 @@ _BINARY_OPS = [
'Or',
'Xor',
'Pow',
+ 'ShiftLeft',
+ 'ShiftRightArithmetic',
+ 'ShiftRightLogical',
]
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 0564ddcb85..93177aa647 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -171,6 +171,24 @@ class ComputationsWithConstantsTest(LocalComputationTest):
c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
+ def testShiftLeft(self):
+ c = self._NewComputation()
+ c.ShiftLeft(c.Constant(NumpyArrayS32([3])),
+ c.Constant(NumpyArrayS32([2])))
+ self._ExecuteAndCompareClose(c, expected=[12])
+
+ def testShiftRightArithmetic(self):
+ c = self._NewComputation()
+ c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])),
+ c.Constant(NumpyArrayS32([1])))
+ self._ExecuteAndCompareClose(c, expected=[-1])
+
+ def testShiftRightLogical(self):
+ c = self._NewComputation()
+ c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])),
+ c.Constant(NumpyArrayS32([1])))
+ self._ExecuteAndCompareClose(c, expected=[2**31 - 1])
+
def testGetProto(self):
c = self._NewComputation()
c.Add(
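
Note on the shift tests above: the expected values follow from how the two right shifts treat the sign bit of a 32-bit two's-complement integer. A small standalone check of the same arithmetic (plain Python, independent of xla_client):

def shift_right_arithmetic_i32(x, n):
  # Python's >> on ints is already arithmetic (sign-preserving).
  return x >> n

def shift_right_logical_i32(x, n):
  # Reinterpret the 32-bit two's-complement bit pattern as unsigned, then shift.
  return (x & 0xFFFFFFFF) >> n

print(3 << 2)                             # 12, matching testShiftLeft
print(shift_right_arithmetic_i32(-2, 1))  # -1, matching testShiftRightArithmetic
print(shift_right_logical_i32(-1, 1))     # 2147483647 == 2**31 - 1
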
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index c74dd648ad..4aacc87b78 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -44,6 +44,10 @@ struct ShapeTreeNode {
// Data corresponding to this node.
std::pair<ShapeIndex, T> data;
+ // Children of this node, as indices into the container's nodes_ array.
+ std::vector<size_t> children;
+
+ // Tells whether this is a leaf node.
bool is_leaf = true;
explicit ShapeTreeNode(ShapeIndex index)
@@ -52,20 +56,6 @@ struct ShapeTreeNode {
: data(std::move(index), std::move(data)) {}
};
-// Internal representation of an index table entry.
-struct IndexTableEntry {
- // Index of the node in the ShapeTreeNode vector.
- uint32 index;
- // Index of the first child in a IndexTableEntry vector. In the index
- // table all children entries for a given node will be placed next to each
- // other. This allows us to use a single field to index them.
- uint32 children_start;
-#ifndef NDEBUG
- // Number of children, used for bounds checking.
- uint32 children_count;
-#endif
-};
-
} // namespace internal
template <typename ContainerType, typename IteratorType, typename ValueType>
@@ -94,7 +84,6 @@ template <typename T>
class ShapeTree {
public:
using Node = internal::ShapeTreeNode<T>;
- using Index = internal::IndexTableEntry;
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
@@ -278,12 +267,11 @@ class ShapeTree {
private:
  // Initialize node->children based on 'shape'. All children are assigned
  // the given 'init_value'.
- void InitChildren(const Shape& shape, const T& init_value, Node* node,
- Index* index);
+ void InitChildren(const Shape& shape, const T& init_value, Node* node);
// Initialize node->children based on 'shape'. All children have
// default-constructed data values.
- void InitChildren(const Shape& shape, Node* node, Index* index);
+ void InitChildren(const Shape& shape, Node* node);
// Returns the number of subshapes, including interior nodes, in shape.
int64 CountSubshapes(const Shape& shape);
@@ -303,9 +291,6 @@ class ShapeTree {
// The nodes in this shape tree.
std::vector<Node> nodes_;
- // Index table for node lookups.
- std::vector<Index> index_table_;
-
// If we own our Shape, this field contains it, and shape_ is a pointer into
// here. Otherwise if we don't own our shape, this is nullptr.
std::shared_ptr<Shape> shape_storage_;
@@ -388,74 +373,36 @@ int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
- Node* node, Index* index) {
+ Node* node) {
if (ShapeUtil::IsTuple(shape)) {
const int64 size = ShapeUtil::TupleElementCount(shape);
-#ifndef NDEBUG
- index->children_count = size;
-#endif
+ node->children.reserve(size);
node->is_leaf = false;
ShapeIndex shape_index = node->data.first;
shape_index.push_back(0);
-
- // At the end of the index_table, reserve a continuous space to hold the
- // children of current node. In order to enforce the invariant that all
- // children of a given node are placed together, we need to do the
- // reservation before we recurse into any of its children.
- int64 children_start_position = index_table_.size();
- index_table_.resize(index_table_.size() + size);
-
for (int i = 0; i < size; ++i) {
shape_index[shape_index.size() - 1] = i;
- index_table_[children_start_position + i].index = nodes_.size();
- // The first child of the node in the index table is placed at the end of
- // the table.
- index_table_[children_start_position + i].children_start =
- index_table_.size();
+ node->children.push_back(nodes_.size());
nodes_.emplace_back(shape_index, init_value);
- InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
- &index_table_[children_start_position + i]);
+ InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back());
}
- } else {
-#ifndef NDEBUG
- index->children_count = 0;
-#endif
}
}
template <typename T>
-void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
+void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
if (ShapeUtil::IsTuple(shape)) {
const int64 size = ShapeUtil::TupleElementCount(shape);
-#ifndef NDEBUG
- index->children_count = size;
-#endif
+ node->children.reserve(size);
node->is_leaf = false;
ShapeIndex shape_index = node->data.first;
shape_index.push_back(0);
-
- // At the end of the index_table, reserve a continuous space to hold the
- // children of current node. In order to enforce the invariant that all
- // children of a given node are placed together, we need to do the
- // reservation before we recurse into any of its children.
- int64 children_start_position = index_table_.size();
- index_table_.resize(index_table_.size() + size);
-
for (int i = 0; i < size; ++i) {
shape_index[shape_index.size() - 1] = i;
- index_table_[children_start_position + i].index = nodes_.size();
- // The first child of the node in the index table is placed at the end of
- // the table.
- index_table_[children_start_position + i].children_start =
- index_table_.size();
+ node->children.push_back(nodes_.size());
nodes_.emplace_back(shape_index);
- InitChildren(shape.tuple_shapes(i), &nodes_.back(),
- &index_table_[children_start_position + i]);
+ InitChildren(shape.tuple_shapes(i), &nodes_.back());
}
- } else {
-#ifndef NDEBUG
- index->children_count = 0;
-#endif
}
}
@@ -466,36 +413,24 @@ ShapeTree<T>::ShapeTree(Shape shape)
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(shape_storage_.get());
- const int64 count = CountSubshapes(*shape_);
- nodes_.reserve(count);
+ nodes_.reserve(CountSubshapes(*shape_));
nodes_.emplace_back(ShapeIndex{});
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, &nodes_[0], &index_table_[0]);
+ InitChildren(*shape_, &nodes_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
- const int64 count = CountSubshapes(*shape_);
- nodes_.reserve(count);
+ nodes_.reserve(CountSubshapes(*shape_));
nodes_.emplace_back(ShapeIndex{});
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, &nodes_[0], &index_table_[0]);
+ InitChildren(*shape_, &nodes_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
: shape_storage_(shape), shape_(shape_storage_.get()) {
- const int64 count = CountSubshapes(*shape_);
- nodes_.reserve(count);
+ nodes_.reserve(CountSubshapes(*shape_));
nodes_.emplace_back(ShapeIndex{});
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, &nodes_[0], &index_table_[0]);
+ InitChildren(*shape_, &nodes_[0]);
}
template <typename T>
@@ -505,38 +440,26 @@ ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(shape_storage_.get());
- const int64 count = CountSubshapes(*shape_);
- nodes_.reserve(count);
+ nodes_.reserve(CountSubshapes(*shape_));
nodes_.emplace_back(ShapeIndex{}, init_value);
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
+ InitChildren(*shape_, init_value, &nodes_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
: shape_(shape) {
- const int64 count = CountSubshapes(*shape_);
- nodes_.reserve(count);
+ nodes_.reserve(CountSubshapes(*shape_));
nodes_.emplace_back(ShapeIndex{}, init_value);
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
+ InitChildren(*shape_, init_value, &nodes_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
const T& init_value)
: shape_storage_(shape), shape_(shape_storage_.get()) {
- const int64 count = CountSubshapes(*shape_);
- nodes_.reserve(count);
+ nodes_.reserve(CountSubshapes(*shape_));
nodes_.emplace_back(ShapeIndex{}, init_value);
-
- index_table_.reserve(count);
- index_table_.emplace_back(Index{0, 1});
- InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
+ InitChildren(*shape_, init_value, &nodes_[0]);
}
template <typename T>
@@ -551,16 +474,13 @@ T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
template <typename T>
internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
- Index* iter = &index_table_[0];
+ Node* node = &nodes_[0];
for (const int64 i : index) {
CHECK_GE(i, 0);
-#ifndef NDEBUG
- CHECK_LT(i, iter->children_count);
-#endif
- iter = &index_table_[iter->children_start + i];
+ CHECK_LT(i, node->children.size());
+ node = &nodes_[node->children[i]];
}
-
- return &nodes_[iter->index];
+ return node;
}
template <typename T>
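
Note on the ShapeTree change above: the separate index table is gone; each node now records the positions of its children within the flat nodes_ vector, and Lookup simply walks those child indices. A plain-Python model of the new layout for a tuple shape containing a leaf and a nested two-element tuple (illustrative structure, not the real C++ template):

# Flat list of nodes; each node stores the positions of its children in this list.
nodes = [
    {"index": (), "children": [1, 2]},      # root: a 2-element tuple
    {"index": (0,), "children": []},        # leaf {0}
    {"index": (1,), "children": [3, 4]},    # nested tuple {1}
    {"index": (1, 0), "children": []},      # leaf {1, 0}
    {"index": (1, 1), "children": []},      # leaf {1, 1}
]

def lookup(shape_index):
  """Mirrors the new ShapeTree::Lookup: walk child indices from the root node."""
  node = nodes[0]
  for i in shape_index:
    node = nodes[node["children"][i]]
  return node

assert lookup((1, 0))["index"] == (1, 0)
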
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index 4391078b64..51de82e957 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -227,16 +227,14 @@ TEST_F(ShapeTreeTest, NestedTupleShape) {
TEST_F(ShapeTreeTest, InvalidIndexingTuple) {
ShapeTree<int> shape_tree{tuple_shape_};
-#ifndef NDEBUG
+
EXPECT_DEATH(shape_tree.element({4}), "");
-#endif
}
TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
ShapeTree<int> shape_tree{nested_tuple_shape_};
-#ifndef NDEBUG
+
EXPECT_DEATH(shape_tree.element({0, 0}), "");
-#endif
}
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
@@ -604,15 +602,12 @@ void BM_Iterate(int iters, int depth, int fan_out) {
}
}
-#define BENCHMARK_WITH_ARGS(name) \
- BENCHMARK(name)->ArgPair(2, 8)->ArgPair(1, 1000)
-
-BENCHMARK_WITH_ARGS(BM_Construct);
-BENCHMARK_WITH_ARGS(BM_ConstructUnowned);
-BENCHMARK_WITH_ARGS(BM_Copy);
-BENCHMARK_WITH_ARGS(BM_Move);
-BENCHMARK_WITH_ARGS(BM_ForEach);
-BENCHMARK_WITH_ARGS(BM_Iterate);
+BENCHMARK(BM_Construct)->ArgPair(2, 8);
+BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8);
+BENCHMARK(BM_Copy)->ArgPair(2, 8);
+BENCHMARK(BM_Move)->ArgPair(2, 8);
+BENCHMARK(BM_ForEach)->ArgPair(2, 8);
+BENCHMARK(BM_Iterate)->ArgPair(2, 8);
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 83d15e8fe3..17c1d7b10a 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
@@ -74,12 +73,10 @@ class ShapeIndex {
// push_front is O(n^2), but shapes don't usually have a ton of dimensions.
void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
- using container_type = gtl::InlinedVector<int64, 2>;
-
- container_type::const_iterator begin() const { return indices_.begin(); }
- container_type::const_iterator end() const { return indices_.end(); }
- container_type::iterator begin() { return indices_.begin(); }
- container_type::iterator end() { return indices_.end(); }
+ std::vector<int64>::const_iterator begin() const { return indices_.begin(); }
+ std::vector<int64>::const_iterator end() const { return indices_.end(); }
+ std::vector<int64>::iterator begin() { return indices_.begin(); }
+ std::vector<int64>::iterator end() { return indices_.end(); }
const int64* data() const { return indices_.data(); }
@@ -100,7 +97,7 @@ class ShapeIndex {
string ToString() const;
private:
- container_type indices_;
+ std::vector<int64> indices_;
};
// A view into a ShapeIndex as above, with the cheap/easy ability to consume the
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index a25232f713..5a5a6ad63a 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -171,8 +171,8 @@ class ControlFlowTransformer(converter.Base):
# actually has some return value as well.
cond_results = None
# TODO(mdan): This doesn't belong here; it's specific to the operator.
- returned_from_body = templates.replace_as_expression('1')
- returned_from_orelse = templates.replace_as_expression('1')
+ returned_from_body = templates.replace_as_expression('tf.constant(1)')
+ returned_from_orelse = templates.replace_as_expression('tf.constant(1)')
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index 6670b8a66f..ade3501426 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -31,7 +31,8 @@ class ControlFlowTest(converter_testing.TestCase):
def assertTransformedResult(self, test_fn, inputs, expected):
if not isinstance(inputs, tuple):
inputs = (inputs,)
- with self.converted(test_fn, control_flow, {}) as result:
+ with self.converted(test_fn, control_flow, {},
+ constant_op.constant) as result:
with self.test_session() as sess:
self.assertEqual(sess.run(result.test_fn(*inputs)), expected)
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
index 4d3d531299..242c1e8ba4 100644
--- a/tensorflow/contrib/checkpoint/python/containers.py
+++ b/tensorflow/contrib/checkpoint/python/containers.py
@@ -35,9 +35,9 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure):
self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker()
slotdeps = self.slotdeps
slots = []
- slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x"
- slots.append(slotdeps.track(tfe.Variable(4.), "y"))
- slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1"
+ slots.append(slotdeps.track(tf.Variable(3.), "x")) # Named "x"
+ slots.append(slotdeps.track(tf.Variable(4.), "y"))
+ slots.append(slotdeps.track(tf.Variable(5.), "x")) # Named "x_1"
```
"""
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 844f62649d..7b892ba248 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -68,6 +68,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc"
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD
index 7b69e10441..566cbb246a 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/contrib/data/kernels/BUILD
@@ -71,8 +71,19 @@ cc_library(
)
cc_library(
+ name = "assert_next_dataset_op",
+ srcs = ["assert_next_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
name = "dataset_kernels",
deps = [
+ ":assert_next_dataset_op",
":csv_dataset_op",
":directed_interleave_dataset_op",
":ignore_errors_dataset_op",
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
new file mode 100644
index 0000000000..95b8e1f7fd
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class AssertNextDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit AssertNextDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ std::vector<string> transformations;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
+ &transformations));
+ *output =
+ new Dataset(ctx, input, transformations, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const std::vector<string>& transformations,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : GraphDatasetBase(ctx),
+ input_(input),
+ transformations_(transformations),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Assert")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "AssertNextDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ Node* transformations_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, transformations_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ std::vector<string> tokens =
+ str_util::Split(prefix(), ':', str_util::SkipEmpty());
+ if (dataset()->transformations_.size() > tokens.size() - 2) {
+ return errors::InvalidArgument(
+ "Asserted next ", dataset()->transformations_.size(),
+ " transformations but encountered only ", tokens.size() - 2, ".");
+ }
+ int n = tokens.size();
+ for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
+ if (dataset()->transformations_[i] != tokens[n - 2 - i]) {
+ return errors::InvalidArgument(
+ "Asserted ", dataset()->transformations_[i],
+ " transformation at offset ", i, " but encountered ",
+ tokens[n - 2 - i], " transformation instead.");
+ }
+ }
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* input_;
+ const std::vector<string> transformations_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
+ AssertNextDatasetOp);
+
+} // namespace
+} // namespace tensorflow
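
Note on Iterator::Initialize above: the asserted transformation names are checked against the iterator's name prefix, reading the prefix tokens from the end (the final token is the Assert iterator itself). A small standalone model of that check in plain Python (the helper name and the example prefix string are illustrative assumptions):

def check_assert_next(prefix, transformations):
  """Mirrors the prefix check in AssertNextDatasetOp's Iterator::Initialize."""
  tokens = [t for t in prefix.split(":") if t]
  if len(transformations) > len(tokens) - 2:
    raise ValueError("Asserted next %d transformations but encountered only %d."
                     % (len(transformations), len(tokens) - 2))
  for i, expected in enumerate(transformations):
    actual = tokens[len(tokens) - 2 - i]
    if expected != actual:
      raise ValueError("Asserted %s transformation at offset %d but encountered "
                       "%s transformation instead." % (expected, i, actual))

# e.g. dataset.apply(assert_next(["Map", "Batch"])).map(f).batch(10) produces an
# iterator name prefix along the lines of "Iterator::Batch::Map::Assert":
check_assert_next("Iterator::Batch::Map::Assert", ["Map", "Batch"])  # passes
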
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index dadde705e1..f7e3ed886c 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -150,6 +150,7 @@ class CSVDatasetOp : public DatasetOpKernel {
delim_(delim),
na_value_(std::move(na_value)),
use_compression_(!compression_type.empty()),
+ compression_type_(std::move(compression_type)),
options_(options) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
@@ -169,10 +170,45 @@ class CSVDatasetOp : public DatasetOpKernel {
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
- // TODO(rachelim): Implement this
- std::vector<Node*> input_tensors;
- TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
- return errors::Unimplemented("CSVDataset: AsGraphDefInternal");
+ Node* filenames = nullptr;
+ Node* compression_type = nullptr;
+ Node* buffer_size = nullptr;
+ Node* header = nullptr;
+ Node* delim = nullptr;
+ Node* use_quote_delim = nullptr;
+ Node* na_value = nullptr;
+ Node* select_cols = nullptr;
+
+ std::vector<Node*> record_defaults;
+ record_defaults.reserve(record_defaults_.size());
+ for (const Tensor& t : record_defaults_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ record_defaults.emplace_back(node);
+ }
+
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(options_.input_buffer_size, &buffer_size));
+ TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+ string delim_string(1, delim_);
+ TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+ TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+ std::make_pair(2, buffer_size), std::make_pair(3, header),
+ std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+ std::make_pair(6, na_value),
+ std::make_pair(7, select_cols)}, // Single tensor inputs
+ {std::make_pair(8, record_defaults)}, // Tensor list inputs
+ {}, output));
+ return Status::OK();
}
private:
@@ -224,14 +260,58 @@ class CSVDatasetOp : public DatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- // TODO(rachelim): Implement save
- return errors::Unimplemented("CSVDataset: SaveInternal");
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+ current_file_index_));
+ // `input_stream_` is empty if
+ // 1. GetNext has not been called even once.
+ // 2. All files have been read and the iterator has been exhausted.
+ if (input_stream_ && num_buffer_reads_ > 0) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+ // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+ num_buffer_reads_));
+ }
+ return Status::OK();
}
+
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
- // TODO(rachelim): Implement restore
- return errors::Unimplemented("CSVDataset: RestoreInternal");
+ ResetStreamsLocked();
+ int64 current_file_index;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+ &current_file_index));
+ current_file_index_ = size_t(current_file_index);
+ // The keys "pos" and "num_buffer_reads" are written only if
+ // the iterator was saved with an open, partially read file.
+ if (reader->Contains(full_name("pos"))) {
+ int64 pos, num_buffer_reads;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+ &num_buffer_reads));
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+ num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+ // Restores the most recently held buffer
+ Status s = input_stream_->SkipNBytes(
+ num_buffer_reads_ * dataset()->options_.input_buffer_size);
+ if (!s.ok() && !errors::IsOutOfRange(s)) {
+        // We might get an out-of-range error here if the size of the file
+        // is not an exact multiple of the buffer size, and the last buffer
+        // read is < buffer_size. This is valid and we do not surface the
+        // error.
+ return s;
+ }
+
+ Status s2 = FillBuffer(&buffer_);
+ if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+ return s2;
+ }
+ pos_ = size_t(pos);
+ }
+ return Status::OK();
}
private:
@@ -533,6 +613,7 @@ class CSVDatasetOp : public DatasetOpKernel {
Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
result->clear();
+ ++num_buffer_reads_;
Status s = input_stream_->ReadNBytes(
dataset()->options_.input_buffer_size, result);
@@ -712,6 +793,7 @@ class CSVDatasetOp : public DatasetOpKernel {
}
buffer_.clear();
pos_ = 0;
+ num_buffer_reads_ = 0;
if (dataset()->header_) {
// Read one line, but don't include it. Pass nullptrs as dummy
// pointers to objects that shouldn't be invoked anyway
@@ -737,6 +819,7 @@ class CSVDatasetOp : public DatasetOpKernel {
string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
size_t pos_ GUARDED_BY(
mu_); // Index into the buffer must be maintained between iters
+ size_t num_buffer_reads_ GUARDED_BY(mu_);
std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
GUARDED_BY(mu_);
std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
@@ -755,6 +838,7 @@ class CSVDatasetOp : public DatasetOpKernel {
const char delim_;
const string na_value_;
const bool use_compression_;
+ const string compression_type_;
const io::ZlibCompressionOptions options_;
}; // class Dataset
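
Note on the CSV save/restore logic above: a checkpoint records the current file index, the in-buffer position, and how many buffers have been filled; restore then skips the buffers that were already consumed, refills the buffer that was live at save time, and resumes parsing at the saved offset. A rough standalone sketch of that restore step over an ordinary file object (plain Python; the names are illustrative, not the real iterator state):

def restore_buffer(f, buffer_size, num_buffer_reads, pos):
  """Skip the buffers consumed before the checkpoint, re-read the buffer that
  was current at save time, and return (buffer, pos) so parsing can resume."""
  # num_buffer_reads counts FillBuffer calls, so num_buffer_reads - 1 complete
  # buffers precede the one that was live when the checkpoint was taken.
  f.seek((num_buffer_reads - 1) * buffer_size)
  buf = f.read(buffer_size)  # may be shorter than buffer_size near EOF; that is fine
  return buf, pos

# Usage sketch: resume parsing the restored buffer at offset `pos`.
# with open("file.csv", "rb") as f:
#   buf, pos = restore_buffer(f, buffer_size=2, num_buffer_reads=3, pos=1)
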
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index a623c27ff8..b5c6f2e241 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -177,4 +177,17 @@ display_name: A human-readable name for the threads that may be visible in
some visualizations.
)doc");
+REGISTER_OP("AssertNextDataset")
+ .Input("input_dataset: variant")
+ .Input("transformations: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // transformations should be a vector.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 18457320b9..d372bed479 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -208,7 +208,6 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index 21eebccd11..cfef40e192 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.core.framework import graph_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@@ -26,41 +25,76 @@ from tensorflow.python.platform import test
class OptimizeDatasetTest(test.TestCase):
+ def testAssertSuffix(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+
+ def testAssertSuffixInvalid(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted Whoops transformation at offset 0 but encountered "
+ "Map transformation instead."
+ ):
+ sess.run(get_next)
+
+ def testAssertSuffixShort(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted next 2 transformations but encountered only 1."):
+ sess.run(get_next)
+
def testDefaultOptimizations(self):
- dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
- 10).apply(optimization.optimize())
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize())
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.test_session() as sess:
- graph = graph_pb2.GraphDef().FromString(
- sess.run(dataset._as_serialized_graph()))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testEmptyOptimizations(self):
- dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
- 10).apply(optimization.optimize([]))
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize([]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.test_session() as sess:
- graph = graph_pb2.GraphDef().FromString(
- sess.run(dataset._as_serialized_graph()))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testOptimization(self):
- dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
- 10).apply(optimization.optimize(["map_and_batch_fusion"]))
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.test_session() as sess:
- graph = graph_pb2.GraphDef().FromString(
- sess.run(dataset._as_serialized_graph()))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 686788522a..3c3f23f9a9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -73,6 +73,20 @@ py_test(
)
py_test(
+ name = "csv_dataset_serialization_test",
+ size = "small",
+ srcs = ["csv_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+py_test(
name = "dataset_constructor_serialization_test",
size = "medium",
srcs = ["dataset_constructor_serialization_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py
new file mode 100644
index 0000000000..247f2046ea
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py
@@ -0,0 +1,73 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the CsvDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.platform import test
+
+
+class CsvDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._num_cols = 7
+ self._num_rows = 10
+ self._num_epochs = 14
+ self._num_outputs = self._num_rows * self._num_epochs
+
+ inputs = [
+ ",".join(str(self._num_cols * j + i)
+ for i in range(self._num_cols))
+ for j in range(self._num_rows)
+ ]
+ contents = "\n".join(inputs).encode("utf-8")
+
+ self._filename = os.path.join(self.get_temp_dir(), "file.csv")
+ self._compressed = os.path.join(self.get_temp_dir(),
+ "comp.csv") # GZip compressed
+
+ with open(self._filename, "wb") as f:
+ f.write(contents)
+ with gzip.GzipFile(self._compressed, "wb") as f:
+ f.write(contents)
+
+ def ds_func(self, **kwargs):
+ compression_type = kwargs.get("compression_type", None)
+ if compression_type == "GZIP":
+ filename = self._compressed
+ elif compression_type is None:
+ filename = self._filename
+ else:
+ raise ValueError("Invalid compression type:", compression_type)
+
+ return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
+
+ def testSerializationCore(self):
+ defs = [[0]] * self._num_cols
+ self.run_core_tests(
+ lambda: self.ds_func(record_defaults=defs, buffer_size=2),
+ lambda: self.ds_func(record_defaults=defs, buffer_size=12),
+ self._num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index cf89657226..018c5115e1 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -18,12 +18,34 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
+# account for indexing) and transformation sequence.
+def assert_next(transformations):
+ """A transformation that asserts which transformations happen next.
+
+ Args:
+ transformations: A `tf.string` vector `tf.Tensor` identifying the
+ transformations that are expected to happen next.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _AssertNextDataset(dataset, transformations)
+
+ return _apply_fn
+
+
def optimize(optimizations=None):
"""A transformation that applies optimizations.
@@ -44,6 +66,37 @@ def optimize(optimizations=None):
return _apply_fn
+class _AssertNextDataset(dataset_ops.Dataset):
+ """A `Dataset` that asserts which transformations happen next."""
+
+ def __init__(self, input_dataset, transformations):
+ """See `assert_next()` for details."""
+ super(_AssertNextDataset, self).__init__()
+ self._input_dataset = input_dataset
+ if transformations is None:
+ raise ValueError("At least one transformation should be specified")
+ self._transformations = ops.convert_to_tensor(
+ transformations, dtype=dtypes.string, name="transformations")
+
+ def _as_variant_tensor(self):
+ return contrib_gen_dataset_ops.assert_next_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._transformations,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
class _OptimizeDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
new file mode 100644
index 0000000000..43c8c355dc
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
@@ -0,0 +1,711 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "0TD5ZrvEMbhZ"
+ },
+ "source": [
+ "##### Copyright 2018 The TensorFlow Authors.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
+ "\n",
+ "# DCGAN: An example with tf.keras and eager\n",
+ "\n",
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "ITZuApL56Mny"
+ },
+ "source": [
+        "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). To do this, we use Deep Convolutional Generative Adversarial Networks ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)).\n",
+ "\n",
+        "On a Colab GPU (Tesla K80), the model takes around 40 seconds per epoch to train.\n",
+ "\n",
+ "Below is the output generated after training the generator and discriminator models for 100 epochs.\n",
+ "\n",
+ "![sample output](https://tensorflow.org/images/gan/dcgan.gif)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "u_2z-B3piVsw"
+ },
+ "outputs": [],
+ "source": [
+ "# to generate gifs\n",
+ "!pip install imageio"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "e1_Y75QXJS6h"
+ },
+ "source": [
+ "## Import TensorFlow and enable eager execution"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "YfIk2es3hJEd"
+ },
+ "outputs": [],
+ "source": [
+ "# Import TensorFlow \u003e= 1.9 and enable eager execution\n",
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "\n",
+ "import os\n",
+ "import time\n",
+ "import numpy as np\n",
+ "import glob\n",
+ "import matplotlib.pyplot as plt\n",
+ "import PIL\n",
+ "import imageio\n",
+ "from IPython import display"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "iYn4MdZnKCey"
+ },
+ "source": [
+ "## Load the dataset\n",
+ "\n",
+ "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will then generate handwritten digits."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "a4fYMGxGhrna"
+ },
+ "outputs": [],
+ "source": [
+ "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "NFC2ghIdiZYE"
+ },
+ "outputs": [],
+ "source": [
+ "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n",
+ "# We are normalizing the images to the range of [-1, 1]\n",
+ "train_images = (train_images - 127.5) / 127.5"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "S4PIDhoDLbsZ"
+ },
+ "outputs": [],
+ "source": [
+ "BUFFER_SIZE = 60000\n",
+ "BATCH_SIZE = 256"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "PIGN6ouoQxt3"
+ },
+ "source": [
+ "## Use tf.data to create batches and shuffle the dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "-yKCCQOoJ7cn"
+ },
+ "outputs": [],
+ "source": [
+ "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "THY-sZMiQ4UV"
+ },
+ "source": [
+ "## Write the generator and discriminator models\n",
+ "\n",
+ "* **Generator** \n",
+ " * It is responsible for **creating the convincing images good enough to fool the discriminator**.\n",
+ " * It consists of Conv2DTranspose(Upsampling) layers. We start with a fully connected layer and upsample the image 2 times so as to reach the desired image size(mnist image size) which is (28, 28, 1). \n",
+ " * We use **leaky relu** activation except for the **last layer** which uses **tanh** activation.\n",
+ " \n",
+ "* **Discriminator**\n",
+ " * **The discriminator is responsible for classifying the fake images from the real images.**\n",
+ " * In other words, the discriminator is given generated images(from the generator) and the real MNIST images. The job of the discriminator is to classify these images into fake(generated) and real(MNIST images).\n",
+ " * **Basically the generator should be good enough to fool the discriminator that the generated images are real**."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "VGLbvBEmjK0a"
+ },
+ "outputs": [],
+ "source": [
+ "class Generator(tf.keras.Model):\n",
+ " def __init__(self):\n",
+ " super(Generator, self).__init__()\n",
+ " self.fc1 = tf.keras.layers.Dense(7*7*64, use_bias=False)\n",
+ " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n",
+ " \n",
+ " self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)\n",
+ " self.batchnorm2 = tf.keras.layers.BatchNormalization()\n",
+ " \n",
+ " self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n",
+ " self.batchnorm3 = tf.keras.layers.BatchNormalization()\n",
+ " \n",
+ " self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n",
+ "\n",
+ " def call(self, x, training=True):\n",
+ " x = self.fc1(x)\n",
+ " x = self.batchnorm1(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = tf.reshape(x, shape=(-1, 7, 7, 64))\n",
+ "\n",
+ " x = self.conv1(x)\n",
+ " x = self.batchnorm2(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = self.conv2(x)\n",
+ " x = self.batchnorm3(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = tf.nn.tanh(self.conv3(x)) \n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "bkOfJxk5j5Hi"
+ },
+ "outputs": [],
+ "source": [
+ "class Discriminator(tf.keras.Model):\n",
+ " def __init__(self):\n",
+ " super(Discriminator, self).__init__()\n",
+ " self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')\n",
+ " self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')\n",
+ " self.dropout = tf.keras.layers.Dropout(0.3)\n",
+ " self.flatten = tf.keras.layers.Flatten()\n",
+ " self.fc1 = tf.keras.layers.Dense(1)\n",
+ "\n",
+ " def call(self, x, training=True):\n",
+ " x = tf.nn.leaky_relu(self.conv1(x))\n",
+ " x = self.dropout(x, training=training)\n",
+ " x = tf.nn.leaky_relu(self.conv2(x))\n",
+ " x = self.dropout(x, training=training)\n",
+ " x = self.flatten(x)\n",
+ " x = self.fc1(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "gDkA05NE6QMs"
+ },
+ "outputs": [],
+ "source": [
+ "generator = Generator()\n",
+ "discriminator = Discriminator()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "0FMYgY_mPfTi"
+ },
+ "source": [
+ "## Define the loss functions and the optimizer\n",
+ "\n",
+ "* **Discriminator loss**\n",
+ " * The discriminator loss function takes 2 inputs; **real images, generated images**\n",
+ " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n",
+ " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n",
+ " * Then the total_loss is the sum of real_loss and the generated_loss\n",
+ " \n",
+ "* **Generator loss**\n",
+ " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**\n",
+ " \n",
+ "\n",
+ "* The discriminator and the generator optimizers are different since we will train them separately."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "wkMNfBWlT-PV"
+ },
+ "outputs": [],
+ "source": [
+ "def discriminator_loss(real_output, generated_output):\n",
+ " # [1,1,...,1] with real output since it is true and we want\n",
+ " # our generated examples to look like it\n",
+ " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n",
+ "\n",
+ " # [0,0,...,0] with generated images since they are fake\n",
+ " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)\n",
+ "\n",
+ " total_loss = real_loss + generated_loss\n",
+ "\n",
+ " return total_loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "90BIcCKcDMxz"
+ },
+ "outputs": [],
+ "source": [
+ "def generator_loss(generated_output):\n",
+ " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "iWCn_PVdEJZ7"
+ },
+ "outputs": [],
+ "source": [
+ "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)\n",
+ "generator_optimizer = tf.train.AdamOptimizer(1e-4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Rw1fkAczTQYh"
+ },
+ "source": [
+ "## Training\n",
+ "\n",
+ "* We start by iterating over the dataset\n",
+ "* The generator is given **noise as an input** which when passed through the generator model will output a image looking like a handwritten digit\n",
+ "* The discriminator is given the **real MNIST images as well as the generated images(from the generator)**.\n",
+ "* Next, we calculate the generator and the discriminator loss.\n",
+ "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n",
+ "\n",
+ "## Generate Images\n",
+ "\n",
+ "* After training, its time to generate some images!\n",
+ "* We start by creating noise array as an input to the generator\n",
+ "* The generator will then convert the noise into handwritten images.\n",
+ "* Last step is to plot the predictions and **voila!**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "NS2GWywBbAWo"
+ },
+ "outputs": [],
+ "source": [
+ "EPOCHS = 150\n",
+ "noise_dim = 100\n",
+ "num_examples_to_generate = 100\n",
+ "\n",
+ "# keeping the random vector constant for generation(prediction) so\n",
+ "# it will be easier to see the improvement of the gan.\n",
+ "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n",
+ " noise_dim])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "RmdVsmvhPxyy"
+ },
+ "outputs": [],
+ "source": [
+ "def generate_and_save_images(model, epoch, test_input):\n",
+ " # make sure the training parameter is set to False because we\n",
+ " # don't want to train the batchnorm layer when doing inference.\n",
+ " predictions = model(test_input, training=False)\n",
+ "\n",
+ " fig = plt.figure(figsize=(10,10))\n",
+ " \n",
+ " for i in range(predictions.shape[0]):\n",
+ " plt.subplot(10, 10, i+1)\n",
+ " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n",
+ " plt.axis('off')\n",
+ " \n",
+ " # tight_layout minimizes the overlap between 2 sub-plots\n",
+ " plt.tight_layout()\n",
+ " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "2M7LmLtGEMQJ"
+ },
+ "outputs": [],
+ "source": [
+ "def train(dataset, epochs, noise_dim): \n",
+ " for epoch in range(epochs):\n",
+ " start = time.time()\n",
+ " \n",
+ " for images in dataset:\n",
+ " # generating noise from a uniform distribution\n",
+ " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n",
+ " \n",
+ " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
+ " generated_images = generator(noise, training=True)\n",
+ " \n",
+ " real_output = discriminator(images, training=True)\n",
+ " generated_output = discriminator(generated_images, training=True)\n",
+ " \n",
+ " gen_loss = generator_loss(generated_output)\n",
+ " disc_loss = discriminator_loss(real_output, generated_output)\n",
+ " \n",
+ " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n",
+ " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n",
+ " \n",
+ " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n",
+ " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))\n",
+ "\n",
+ " \n",
+ " if epoch % 10 == 0:\n",
+ " display.clear_output(wait=True)\n",
+ " generate_and_save_images(generator,\n",
+ " epoch + 1,\n",
+ " random_vector_for_generation)\n",
+ "\n",
+ " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n",
+ " time.time()-start))\n",
+ " # generating after the final epoch\n",
+ " generate_and_save_images(generator,\n",
+ " epochs,\n",
+ " random_vector_for_generation)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "Ly3UN0SLLY2l"
+ },
+ "outputs": [],
+ "source": [
+ "train(train_dataset, EPOCHS, noise_dim)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "P4M_vIbUi7c0"
+ },
+ "source": [
+ "# Display an image using the epoch number"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "WfO5wCdclHGL"
+ },
+ "outputs": [],
+ "source": [
+ "def display_image(epoch_no):\n",
+ " plt.figure(figsize=(15,15))\n",
+ " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n",
+ " plt.axis('off')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "5x3q9_Oe5q0A"
+ },
+ "outputs": [],
+ "source": [
+ "display_image(EPOCHS)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "NywiH3nL8guF"
+ },
+ "source": [
+ "## Generate a GIF of all the saved images."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xmO0Dmu2WICn"
+ },
+ "source": [
+ "\u003c!-- TODO(markdaoust): Remove the hack when Ipython version is updated --\u003e\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "IGKQgENQ8lEI"
+ },
+ "outputs": [],
+ "source": [
+ "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n",
+ " filenames = glob.glob('image*.png')\n",
+ " filenames = sorted(filenames)\n",
+ " for filename in filenames:\n",
+ " image = imageio.imread(filename)\n",
+ " writer.append_data(image)\n",
+ " # this is a hack to display the gif inside the notebook\n",
+ " os.system('mv dcgan.gif dcgan.gif.png')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "uV0yiKpzNP1b"
+ },
+ "outputs": [],
+ "source": [
+ "display.Image(filename=\"dcgan.gif.png\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "4UJjSnIMOzOJ"
+ },
+ "outputs": [],
+ "source": [
+ ""
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "default_view": {},
+ "name": "dcgan.ipynb",
+ "private_outputs": true,
+ "provenance": [
+ {
+ "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp",
+ "timestamp": 1527173385672
+ }
+ ],
+ "toc_visible": true,
+ "version": "0.3.2",
+ "views": {}
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
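As a framework-free cross-check of the loss formulation the notebook describes (real images scored against ones, generated images against zeros), here is a minimal numpy sketch; the stable sigmoid-cross-entropy form below is intended to mirror, roughly, the mean-reduced value of tf.losses.sigmoid_cross_entropy, and the names are illustrative.

    import numpy as np

    def sigmoid_cross_entropy(labels, logits):
        # Numerically stable sigmoid cross entropy with logits, averaged over elements.
        return np.mean(np.maximum(logits, 0) - logits * labels
                       + np.log1p(np.exp(-np.abs(logits))))

    def discriminator_loss(real_logits, generated_logits):
        # Real images should be scored as 1, generated images as 0.
        return (sigmoid_cross_entropy(np.ones_like(real_logits), real_logits) +
                sigmoid_cross_entropy(np.zeros_like(generated_logits), generated_logits))

    def generator_loss(generated_logits):
        # The generator wants its images scored as real (1).
        return sigmoid_cross_entropy(np.ones_like(generated_logits), generated_logits)

    logits = np.array([2.0, -1.0, 0.5])
    print(discriminator_loss(logits, -logits), generator_loss(-logits))
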
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
index 729d8525fa..275aee5130 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
@@ -54,7 +54,7 @@ class Dynamics(tf.keras.Model):
self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)
- self.eps = tfe.Variable(
+ self.eps = tf.Variable(
initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
def apply_transition(self, position):
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
index e230ad5e25..68e0bc3123 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
@@ -25,7 +25,6 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
-import tensorflow.contrib.eager as tfe
class GenericNet(tf.keras.Model):
@@ -47,13 +46,13 @@ class GenericNet(tf.keras.Model):
# Scale
self.scale_layer = _custom_dense(x_dim, .001)
- self.coeff_scale = tfe.Variable(
+ self.coeff_scale = tf.Variable(
initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True)
# Translation
self.translation_layer = _custom_dense(x_dim, factor=.001)
# Transformation
self.transformation_layer = _custom_dense(x_dim, .001)
- self.coeff_transformation = tfe.Variable(
+ self.coeff_transformation = tf.Variable(
initial_value=tf.zeros([1, x_dim]),
name='coeff_transformation',
trainable=True)
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
index 591e2d0c85..5f1b48fa0d 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
@@ -118,7 +118,6 @@
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
- "tfe = tf.contrib.eager # Shorthand for some symbols\n",
"\n",
"tf.enable_eager_execution()"
],
@@ -184,7 +183,7 @@
},
"cell_type": "code",
"source": [
- "v = tfe.Variable(1.0)\n",
+ "v = tf.Variable(1.0)\n",
"assert v.numpy() == 1.0\n",
"\n",
"# Re-assign the value\n",
@@ -258,8 +257,8 @@
" def __init__(self):\n",
" # Initialize variable to (5.0, 0.0)\n",
" # In practice, these should be initialized to random values.\n",
- " self.W = tfe.Variable(5.0)\n",
- " self.b = tfe.Variable(0.0)\n",
+ " self.W = tf.Variable(5.0)\n",
+ " self.b = tf.Variable(0.0)\n",
" \n",
" def __call__(self, x):\n",
" return self.W * x + self.b\n",
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index b2ac4b67c9..b0d0a5486d 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -138,7 +138,7 @@ class RevNetTest(tf.test.TestCase):
minval=0,
maxval=self.config.n_classes,
dtype=tf.int32)
- global_step = tfe.Variable(0., trainable=False)
+ global_step = tf.Variable(0., trainable=False)
model = revnet.RevNet(config=config)
model(x)
updates = model.get_updates_for(x)
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index c2340a293a..d64bf5354e 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -310,7 +310,7 @@ def main(_):
with tf.device("/device:GPU:0" if have_gpu else None):
# Make learning_rate a Variable so it can be included in the checkpoint
# and we can resume training with the last saved learning_rate.
- learning_rate = tfe.Variable(20.0, name="learning_rate")
+ learning_rate = tf.Variable(20.0, name="learning_rate")
model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
use_cudnn_rnn)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
index 561be36c91..8130414985 100644
--- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py
+++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
@@ -62,7 +62,7 @@ class SelfAttentionModule(tf.keras.Model):
kernel_size=1,
strides=(1, 1),
data_format=data_format)
- self.scale = tfe.Variable(0., trainable=True)
+ self.scale = tf.Variable(0., trainable=True)
def call(self, x):
f = self.f(x)
diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
index 4f1410e00b..f3a65f5aab 100644
--- a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
+++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
@@ -69,7 +69,7 @@
"cell_type": "code",
"source": [
"# Creating variables\n",
- "v = tfe.Variable(1.0)\n",
+ "v = tf.Variable(1.0)\n",
"v"
],
"execution_count": 2,
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
index db50b33af2..4454abfb96 100644
--- a/tensorflow/contrib/eager/python/tfe_test.py
+++ b/tensorflow/contrib/eager/python/tfe_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
-from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer
@@ -45,12 +44,6 @@ class TFETest(test_util.TensorFlowTestCase):
r'indices = 7 is not in \[0, 3\)'):
array_ops.gather([0, 1, 2], 7)
- def testVariableError(self):
- with self.assertRaisesRegexp(
- RuntimeError,
- r'Variable not supported when eager execution is enabled'):
- variables.Variable(initial_value=1.0)
-
def testGradients(self):
def square(x):
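The repeated tfe.Variable -> tf.Variable substitutions above (and the removal of testVariableError) rely on tf.Variable now working directly under eager execution. A minimal sketch of the behavior these examples depend on, assuming a TensorFlow version from this era:

    import tensorflow as tf
    tf.enable_eager_execution()

    v = tf.Variable(1.0)   # previously this required tfe.Variable under eager execution
    v.assign_add(2.0)
    assert v.numpy() == 3.0
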
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 9d8c20e96f..9f31ffdf67 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -39,6 +39,38 @@ cc_test(
)
cc_library(
+ name = "delegate_data",
+ srcs = ["delegate_data.cc"],
+ hdrs = ["delegate_data.h"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":buffer_map",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/common_runtime/eager:context",
+ ],
+)
+
+cc_test(
+ name = "delegate_data_test",
+ size = "small",
+ srcs = ["delegate_data_test.cc"],
+ tags = [
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":delegate_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
index e4a780b735..1d6453f498 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/log_memory.h"
namespace tflite {
+namespace eager {
namespace {
// A tensor buffer that is allocated, deallocated and populated by TF Lite.
class TfLiteTensorBuffer : public tensorflow::TensorBuffer {
@@ -102,4 +103,5 @@ void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) {
id_to_tensor_[tensor_index] = std::move(tensor);
}
+} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
index 922f67f574..a28329ae7d 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tflite {
+namespace eager {
// Maps a TF Lite tensor index into a TensorFlow tensor.
//
@@ -54,6 +55,7 @@ class BufferMap {
std::map<int, tensorflow::Tensor> id_to_tensor_;
};
+} // namespace eager
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc
index c447eeaa05..dcb3f6c941 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
+namespace eager {
namespace {
using ::testing::ElementsAre;
@@ -163,6 +164,7 @@ TEST(BufferMapTest, TensorFlowOverwritesTfLite) {
}
} // namespace
+} // namespace eager
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
new file mode 100644
index 0000000000..29687694bd
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tflite {
+namespace eager {
+tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
+ std::vector<tensorflow::Device*> devices;
+
+ TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
+ tensorflow::SessionOptions(), "/device:cpu:*", &devices));
+
+ std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
+ new tensorflow::DeviceMgr(devices));
+ // Note that Rendezvous is ref-counted so it will be automatically deleted.
+ tensorflow::Rendezvous* rendezvous =
+ new tensorflow::IntraProcessRendezvous(device_mgr.get());
+ data->reset(new DelegateData(new tensorflow::EagerContext(
+ tensorflow::SessionOptions(),
+ tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+ /*async=*/false, std::move(device_mgr), rendezvous)));
+ return tensorflow::Status();
+}
+
+DelegateData::DelegateData(tensorflow::EagerContext* eager_context)
+ : eager_context_(eager_context) {}
+
+DelegateData::~DelegateData() {}
+
+} // namespace eager
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/eager/delegate_data.h
new file mode 100644
index 0000000000..8a0e8ba8bf
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
+
+#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+
+namespace tflite {
+namespace eager {
+
+// Data kept by the Eager delegate for the lifetime of an Interpreter.
+class DelegateData {
+ public:
+ // Create a new DelegateData, initialized with a newly-created EagerContext.
+ static tensorflow::Status Create(std::unique_ptr<DelegateData>* data);
+
+ ~DelegateData();
+
+ // The EagerContext that is required for execution of Eager Ops.
+ tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
+
+ // Map from TF Lite tensor index to TensorFlow tensor.
+ BufferMap* GetBufferMap() { return &buffer_map_; }
+
+ private:
+ explicit DelegateData(tensorflow::EagerContext* eager_context);
+
+ std::unique_ptr<tensorflow::EagerContext> eager_context_;
+ BufferMap buffer_map_;
+};
+
+} // namespace eager
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
new file mode 100644
index 0000000000..30251b8f82
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace eager {
+namespace {
+
+TEST(DelegateDataTest, Basic) {
+ std::unique_ptr<DelegateData> data;
+ // We only check for success because it is hard to make initialization fail.
+ // It only happens if we manage to not link the CPU device factory into the
+ // binary.
+ EXPECT_TRUE(DelegateData::Create(&data).ok());
+
+ EXPECT_NE(data->GetEagerContext(), nullptr);
+ EXPECT_NE(data->GetBufferMap(), nullptr);
+}
+
+} // namespace
+} // namespace eager
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index e1879bdaff..4426c653e6 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/delegates/eager/util.h"
namespace tflite {
+namespace eager {
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status) {
@@ -67,4 +68,5 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
}
}
+} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index 12b33b9b49..a9407be071 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
+namespace eager {
// Converts a tensorflow:Status into a TfLiteStatus. If the original status
// represented an error, reports it using the given 'context'.
@@ -35,6 +36,7 @@ TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
// Returns the TF C API Data type that corresponds to the given TfLiteType.
TF_DataType GetTensorFlowDataType(TfLiteType type);
+} // namespace eager
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
index 53ed4db972..c4fbf54127 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
+namespace eager {
namespace {
using tensorflow::DT_FLOAT;
@@ -102,6 +103,7 @@ TEST(UtilTest, TypeConversions) {
}
} // namespace
+} // namespace eager
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc
index 474d323bc3..74b3aef5bd 100644
--- a/tensorflow/contrib/lite/kernels/pow_test.cc
+++ b/tensorflow/contrib/lite/kernels/pow_test.cc
@@ -50,22 +50,22 @@ class PowOpModel : public SingleOpModel {
};
TEST(PowOpModel, Simple) {
- PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {}});
- model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 1});
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8));
}
TEST(PowOpModel, NegativeAndZeroValue) {
- PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {}});
- model.PopulateTensor<int32>(model.input1(), {0, 2, -7, 8});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 0});
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {0, 2, -7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1));
@@ -98,10 +98,10 @@ TEST(PowOpModel, NegativeFloatTest) {
}
TEST(PowOpModel, BroadcastTest) {
- PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
- {TensorType_INT32, {1}}, {TensorType_INT32, {}});
- model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8});
- model.PopulateTensor<int32>(model.input2(), {4});
+ PowOpModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1}}, {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {4});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096));
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index c38b692dcd..f97919363b 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -340,6 +340,8 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
namespace {
+// Checks whether a tensor access can succeed. Returns nullptr on error;
+// otherwise returns Py_None.
PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
TfLiteTensor** tensor, int* type_num) {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
@@ -362,7 +364,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
return nullptr;
}
- return nullptr;
+ Py_RETURN_NONE;
}
} // namespace
@@ -371,10 +373,12 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
// Sanity check accessor
TfLiteTensor* tensor = nullptr;
int type_num = 0;
- if (PyObject* pynone_or_nullptr =
- CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
- return pynone_or_nullptr;
- }
+
+ PyObject* check_result =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
+ if (check_result == nullptr) return check_result;
+ Py_XDECREF(check_result);
+
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
// Make a buffer copy but we must tell Numpy It owns that data or else
@@ -396,10 +400,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
// Sanity check accessor
TfLiteTensor* tensor = nullptr;
int type_num = 0;
- if (PyObject* pynone_or_nullptr =
- CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
- return pynone_or_nullptr;
- }
+
+ PyObject* check_result =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
+ if (check_result == nullptr) return check_result;
+ Py_XDECREF(check_result);
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 5e197e584c..c88079717d 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -93,6 +93,7 @@ cc_library(
":runtime",
":toco_port",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -246,6 +247,7 @@ cc_library(
"graph_transformations/resolve_constant_transpose.cc",
"graph_transformations/resolve_constant_unary.cc",
"graph_transformations/resolve_fake_quant_args_from_vars.cc",
+ "graph_transformations/resolve_gather_attributes.cc",
"graph_transformations/resolve_multiply_by_zero.cc",
"graph_transformations/resolve_pad_attributes.cc",
"graph_transformations/resolve_padv2_attributes.cc",
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 4508aa6632..f9a6d31d60 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -215,6 +215,30 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
LegacyScalarPolicy::kAvoidLegacyScalars);
}
+void ConvertBoolTensorConst(const Model& model, const string& name,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ CHECK(model.HasArray(name));
+ const auto& array = model.GetArray(name);
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_BOOL);
+ const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
+ for (auto index : data) {
+ tensor->add_bool_val(index);
+ }
+ const auto& array_shape = array.shape();
+ auto* shape = tensor->mutable_tensor_shape();
+ for (int i = 0; i < array_shape.dimensions_count(); i++) {
+ shape->add_dim()->set_size(array_shape.dims(i));
+ }
+}
+
void ConvertIntTensorConst(const Model& model, const string& name,
GraphDef* tensorflow_graph) {
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
@@ -621,7 +645,8 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op,
CHECK_EQ(src_op.inputs.size(), 2);
*add_op->add_input() = src_op.inputs[0];
*add_op->add_input() = src_op.inputs[1];
- (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*add_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
}
void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
@@ -633,7 +658,8 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
*add_op->add_input() = input;
}
(*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
- (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*add_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
}
void ConvertMulOperator(const Model& model, const MulOperator& src_op,
@@ -644,16 +670,18 @@ void ConvertMulOperator(const Model& model, const MulOperator& src_op,
CHECK_EQ(src_op.inputs.size(), 2);
*add_op->add_input() = src_op.inputs[0];
*add_op->add_input() = src_op.inputs[1];
- (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*add_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
}
-void ConvertReluOperator(const ReluOperator& src_op,
+void ConvertReluOperator(const Model& model, const ReluOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
relu_op->set_op("Relu");
relu_op->set_name(src_op.outputs[0]);
*relu_op->add_input() = src_op.inputs[0];
- (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*relu_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
}
void ConvertRelu1Operator(const Relu1Operator& src_op,
@@ -1110,13 +1138,27 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
- gather_op->set_op("Gather");
+ gather_op->set_op("GatherV2");
gather_op->set_name(src_op.outputs[0]);
- CHECK_EQ(src_op.inputs.size(), 2);
*gather_op->add_input() = src_op.inputs[0];
*gather_op->add_input() = src_op.inputs[1];
+ if (!src_op.axis) {
+ // Dynamic axis.
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *gather_op->add_input() = src_op.inputs[2];
+ } else {
+ // Constant axis.
+ CHECK_EQ(src_op.inputs.size(), 2);
+ const string gather_axis =
+ AvailableArrayName(model, gather_op->name() + "/axis");
+ CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {},
+ tensorflow_graph);
+ *gather_op->add_input() = gather_axis;
+ }
+
(*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
+ (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32);
const tensorflow::DataType params_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
(*gather_op->mutable_attr())["Tparams"].set_type(params_type);
@@ -1638,6 +1680,9 @@ void ConvertReduceOperator(const Model& model, const T& src_op,
const tensorflow::DataType params_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
+ const tensorflow::DataType indices_type =
+ GetTensorFlowDataType(model, src_op.inputs[1]);
+ (*new_op->mutable_attr())["Tidx"].set_type(indices_type);
if (src_op.keep_dims) {
(*new_op->mutable_attr())["keep_dims"].set_b(true);
@@ -1694,43 +1739,43 @@ void ConvertSubOperator(const Model& model, const SubOperator& src_op,
void ConvertTensorFlowMinimumOperator(const Model& model,
const TensorFlowMinimumOperator& src_op,
GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
- sub_op->set_op("Minimum");
- sub_op->set_name(src_op.outputs[0]);
+ tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
+ min_op->set_op("Minimum");
+ min_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
- *sub_op->add_input() = src_op.inputs[0];
- *sub_op->add_input() = src_op.inputs[1];
+ *min_op->add_input() = src_op.inputs[0];
+ *min_op->add_input() = src_op.inputs[1];
const tensorflow::DataType data_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
- (*sub_op->mutable_attr())["T"].set_type(data_type);
+ (*min_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertTensorFlowMaximumOperator(const Model& model,
const TensorFlowMaximumOperator& src_op,
GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
- sub_op->set_op("Maximum");
- sub_op->set_name(src_op.outputs[0]);
+ tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
+ max_op->set_op("Maximum");
+ max_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
- *sub_op->add_input() = src_op.inputs[0];
- *sub_op->add_input() = src_op.inputs[1];
+ *max_op->add_input() = src_op.inputs[0];
+ *max_op->add_input() = src_op.inputs[1];
const tensorflow::DataType data_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
- (*sub_op->mutable_attr())["T"].set_type(data_type);
+ (*max_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
- sub_op->set_op("Select");
- sub_op->set_name(src_op.outputs[0]);
+ tensorflow::NodeDef* select_op = tensorflow_graph->add_node();
+ select_op->set_op("Select");
+ select_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 3);
- *sub_op->add_input() = src_op.inputs[0];
- *sub_op->add_input() = src_op.inputs[1];
- *sub_op->add_input() = src_op.inputs[2];
+ *select_op->add_input() = src_op.inputs[0];
+ *select_op->add_input() = src_op.inputs[1];
+ *select_op->add_input() = src_op.inputs[2];
const tensorflow::DataType data_type =
GetTensorFlowDataType(model, src_op.inputs[1]);
- (*sub_op->mutable_attr())["T"].set_type(data_type);
+ (*select_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertTileOperator(const Model& model,
@@ -1753,11 +1798,14 @@ void ConvertTileOperator(const Model& model,
void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
- topk_op->set_op("TOPKV2");
+ topk_op->set_op("TopKV2");
topk_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*topk_op->add_input() = src_op.inputs[0];
*topk_op->add_input() = src_op.inputs[1];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*topk_op->mutable_attr())["T"].set_type(data_type);
(*topk_op->mutable_attr())["sorted"].set_b(true);
}
@@ -1828,6 +1876,43 @@ void ConvertPowOperator(const Model& model, const PowOperator& src_op,
(*pow_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertAnyOperator(const Model& model, const AnyOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* any_op = tensorflow_graph->add_node();
+ any_op->set_op("Any");
+ any_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *any_op->add_input() = src_op.inputs[i];
+ }
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[1]);
+ (*any_op->mutable_attr())["Tidx"].set_type(data_type);
+ (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims);
+}
+
+void ConvertLogicalAndOperator(const Model& model,
+ const LogicalAndOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
+ logical_op->set_op("LogicalAnd");
+ logical_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *logical_op->add_input() = src_op.inputs[i];
+ }
+}
+
+void ConvertLogicalNotOperator(const Model& model,
+ const LogicalNotOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
+ logical_op->set_op("LogicalNot");
+ logical_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *logical_op->add_input() = src_op.inputs[0];
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1864,7 +1949,7 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kRelu) {
- ConvertReluOperator(static_cast<const ReluOperator&>(src_op),
+ ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kRelu1) {
ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
@@ -1974,6 +2059,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertReduceOperator(model,
static_cast<const TensorFlowProdOperator&>(src_op),
tensorflow_graph, "Prod");
+ } else if (src_op.type == OperatorType::kReduceMin) {
+ ConvertReduceOperator(model,
+                          static_cast<const TensorFlowMinOperator&>(src_op),
+ tensorflow_graph, "Min");
} else if (src_op.type == OperatorType::kReduceMax) {
ConvertReduceOperator(model,
static_cast<const TensorFlowMaxOperator&>(src_op),
@@ -2060,6 +2149,17 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kPow) {
ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kAny) {
+ ConvertAnyOperator(model, static_cast<const AnyOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLogicalAnd) {
+ ConvertLogicalAndOperator(model,
+ static_cast<const LogicalAndOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLogicalNot) {
+ ConvertLogicalNotOperator(model,
+ static_cast<const LogicalNotOperator&>(src_op),
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
@@ -2138,6 +2238,9 @@ void ExportTensorFlowGraphDefImplementation(const Model& model,
const auto& array = *array_pair.second;
if (array.buffer) {
switch (array.data_type) {
+ case ArrayDataType::kBool:
+ ConvertBoolTensorConst(model, array_name, tensorflow_graph);
+ break;
case ArrayDataType::kFloat:
ConvertFloatTensorConst(model, array_name, tensorflow_graph);
break;
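The Gather export above now emits GatherV2 with an explicit axis input (materialized as an int32 constant when the axis is static). A small numpy sketch of the resulting shape semantics, under the assumption that np.take matches GatherV2 for a single constant axis:

    import numpy as np

    params = np.arange(24).reshape(2, 3, 4)
    indices = np.array([2, 0])

    # Output shape = params.shape[:axis] + indices.shape + params.shape[axis + 1:]
    out = np.take(params, indices, axis=1)
    assert out.shape == (2, 2, 4)
    print(out[0])  # rows 2 and 0 of params[0]
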
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
index 56f48d47de..310a88484c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
@@ -40,11 +40,6 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
// Yield until input dims have been resolved.
return false;
}
- if (input_array.shape().dimensions_count() == 0) {
- // Input array cannot be 0-D.
- // (Unsure if this is TF behavior, but was required to get a test to pass.)
- return false;
- }
const auto& axis_array = model->GetArray(expand_op->inputs[1]);
if (!axis_array.has_shape()) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 5cee08fd4c..b7634e28c6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -195,6 +195,7 @@ DECLARE_GRAPH_TRANSFORMATION(Dequantize)
DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup)
DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights)
DECLARE_GRAPH_TRANSFORMATION(ResolveFakeQuantArgsFromVars)
+DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes)
class PropagateDefaultMinMax : public GraphTransformation {
public:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 670bcf64e7..3dda536ef7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -62,6 +62,9 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
case OperatorType::kGreaterEqual:
case OperatorType::kEqual:
case OperatorType::kNotEqual:
+ case OperatorType::kAny:
+ case OperatorType::kLogicalAnd:
+ case OperatorType::kLogicalNot:
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 5e2ba0eca7..62ed5c46e9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -437,6 +437,7 @@ void ProcessTensorFlowReshapeOperator(Model* model,
product_non_wildcard_dims *= shape_data[i];
}
}
+
const int input_flat_size = RequiredBufferSizeForShape(input_shape);
if (has_wildcard) {
CHECK_GE(input_flat_size, product_non_wildcard_dims)
@@ -445,6 +446,12 @@ void ProcessTensorFlowReshapeOperator(Model* model,
<< op->outputs[0] << "\". Are your input shapes correct?";
shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
}
+
+ if (shape_data.size() == 1 && shape_data[0] == 0) {
+ // We have reshaped a scalar, so preserve as a scalar.
+ shape_data.clear();
+ }
+
auto& output_shape = *output_array.mutable_shape();
*output_shape.mutable_dims() = shape_data;
CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
@@ -522,7 +529,7 @@ void ProcessAddNOperator(Model* model, Operator* op) {
bool KeepDims(const Operator& op) {
switch (op.type) {
- case OperatorType::kMin: // Reduction Min
+ case OperatorType::kReduceMin: // Reduction Min
return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
case OperatorType::kReduceMax: // Reduction Max
return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
@@ -1036,17 +1043,28 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) {
return;
}
+ // Yield until the axis has been resolved.
+ if (!op->axis) {
+ return;
+ }
+ int axis = op->axis.value();
+
const auto& input_shape = input_array.shape();
const auto& indices_shape = indices_array.shape();
QCHECK_GE(input_shape.dimensions_count(), 1);
op->input_rank = input_shape.dimensions_count();
+ QCHECK_LT(axis, op->input_rank);
- // Copy the input dimensions to the output except for dimension 0,
+  // Copy the input dimensions to the output except for the axis dimension,
// where the dimension of indices_shape is used.
- // TODO(mgubin): if axis != 0 this is not true, change when it's supported.
auto output_dims = output_array.mutable_shape()->mutable_dims();
- output_dims->push_back(indices_shape.dims(0));
- for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
+ for (int dim = 0; dim < axis; ++dim) {
+ output_dims->push_back(input_shape.dims(dim));
+ }
+ for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
+ output_dims->push_back(indices_shape.dims(dim));
+ }
+ for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
output_dims->push_back(input_shape.dims(dim));
}
}
@@ -1501,6 +1519,65 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
}
}
+void ProcessAnyOperator(Model* model, AnyOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // We have already run.
+ return;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ auto& reduction_indices_array = model->GetArray(op->inputs[1]);
+ if (!reduction_indices_array.has_shape()) {
+    // Yield until the reduction indices shape has been resolved.
+ return;
+ }
+ if (!reduction_indices_array.buffer) {
+ // Yield until the reduction indices are constant.
+ return;
+ }
+ CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32)
+ << "Any reduction input must be int32";
+
+ int input_rank = input_shape.dimensions_count();
+ std::set<int32> true_indices;
+ const auto& reduction_indices =
+ reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < reduction_indices.size(); ++i) {
+ const int32 reduction_index = reduction_indices[i];
+ if (reduction_index < -input_rank || reduction_index >= input_rank) {
+ CHECK(false) << "Invalid reduction dimension " << reduction_index
+ << " for input with " << input_rank << " dimensions";
+ }
+ int32 wrapped_index = reduction_index;
+ if (wrapped_index < 0) {
+ wrapped_index += input_rank;
+ }
+ true_indices.insert(wrapped_index);
+ }
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->clear();
+ for (int i = 0; i < input_rank; ++i) {
+ if (true_indices.count(i) > 0) {
+ if (op->keep_dims) {
+ mutable_dims->emplace_back(1);
+ }
+ } else {
+ mutable_dims->emplace_back(input_shape.dims(i));
+ }
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
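
The shape rule implemented by ProcessAnyOperator above is the usual reduction rule: negative indices wrap, and reduced axes are either dropped or kept as size 1 depending on keep_dims. A standalone sketch over plain vectors (assumed names, not toco types):

```c++
#include <set>
#include <vector>

std::vector<int> ReducedShape(const std::vector<int>& input_shape,
                              const std::vector<int>& reduction_indices,
                              bool keep_dims) {
  const int rank = static_cast<int>(input_shape.size());
  std::set<int> reduced;
  for (int axis : reduction_indices) {
    reduced.insert(axis < 0 ? axis + rank : axis);  // wrap negative axes
  }
  std::vector<int> out;
  for (int d = 0; d < rank; ++d) {
    if (reduced.count(d)) {
      if (keep_dims) out.push_back(1);  // reduced axis kept as size 1
    } else {
      out.push_back(input_shape[d]);    // untouched axis copied through
    }
  }
  return out;
}
// Example: input {2, 3, 4}, indices {1}: keep_dims -> {2, 1, 4}, else {2, 4}.
```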
@@ -1539,6 +1616,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kFloor:
case OperatorType::kExp:
case OperatorType::kSin:
+ case OperatorType::kLogicalAnd:
+ case OperatorType::kLogicalNot:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
@@ -1607,7 +1686,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kL2Pool:
ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
break;
- case OperatorType::kMin: // Reduction Min
+ case OperatorType::kReduceMin: // Reduction Min
case OperatorType::kReduceMax: // Reduction Max
case OperatorType::kSum:
case OperatorType::kReduceProd:
@@ -1732,6 +1811,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTile:
ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
break;
+ case OperatorType::kAny:
+ ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
index 404f27e067..5295eeccec 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
@@ -59,6 +59,15 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
if (CountOpsWithInput(model, op.outputs[0]) == 1) {
const auto* next_op = GetOpWithInput(model, op.outputs[0]);
if (next_op->type == OperatorType::kReshape) {
+ if (!IsDiscardableArray(model, next_op->outputs[0])) {
+        // If the |next_op| output is used as a model output, we need to preserve
+ // its shape.
+ transformation->AddMessageF(
+ "%s cannot be merged into following reshape %s as it is "
+ "non-discardable and must keep the specified shape",
+ LogName(op), LogName(*next_op));
+ return false;
+ }
transformation->AddMessageF(
"%s is trivial because its output is only consumed by another "
"Reshape op %s",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
index debe298a5a..36d7dad0ce 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
@@ -69,7 +69,7 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
}
const auto* op = static_cast<const GatherOperator*>(base_op);
- CHECK_EQ(op->inputs.size(), 2);
+ CHECK_GE(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
@@ -81,10 +81,14 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
return false;
}
- // Only handling axis=0 for now.
- if (op->axis != 0) {
+ if (!op->axis) {
+ // Yield until axis has been set by ResolveGatherAttributes.
+ return false;
+ }
+ if (op->axis.value() != 0) {
+ // Only handling axis=0 for now.
AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op),
- op->axis);
+ op->axis.value());
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index 51099cf74a..fe3882c28d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -57,7 +57,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
case OperatorType::kSqrt:
case OperatorType::kSquare:
case OperatorType::kSum:
- case OperatorType::kMin: // Reduction Min
+ case OperatorType::kReduceMin: // Reduction Min
case OperatorType::kReduceMax: // Reduction Max
case OperatorType::kReshape:
case OperatorType::kRelu6:
@@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
output_float_data[i] = sum;
}
- } else if (unary_op->type == OperatorType::kMin) {
+ } else if (unary_op->type == OperatorType::kReduceMin) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
for (int i = 0; i < output_dims_count; i++) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
new file mode 100644
index 0000000000..ce825c91af
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
+ auto* gather_op = model->operators[op_index].get();
+ if (gather_op->type != OperatorType::kGather) return false;
+ auto* op = static_cast<GatherOperator*>(gather_op);
+
+ if (op->axis) {
+ // Attributes already resolved
+ return false;
+ }
+ if (op->inputs.size() != 3) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+
+ const auto& indices_array = model->GetArray(op->inputs[2]);
+ if (!indices_array.has_shape()) return false;
+ const auto& axis_data = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(axis_data.size(), 1)
+ << "Multidimensional gather not supported on " << LogName(*op);
+ op->axis = {axis_data[0]};
+
+ // Drop the axis array as we no longer need it.
+ DeleteArrayIfUsedOnce(op->inputs[2], model);
+ op->inputs.resize(2);
+
+ return true;
+}
+
+} // namespace toco
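
As a rough illustration of what this new pass does to a GatherV2-style node (hypothetical struct and helper, not toco's API): a constant third input is folded into the axis attribute and then dropped from the input list.

```c++
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a gather node; toco's real GatherOperator differs.
struct FakeGather {
  std::vector<std::string> inputs;  // {"params", "indices", "axis_const"}
  std::optional<int> axis;          // unset until the pass runs
};

// Fold a known-constant axis value into the attribute and drop that input.
bool FoldAxisInput(FakeGather* op, int constant_axis_value) {
  if (op->axis || op->inputs.size() != 3) return false;  // nothing to do
  op->axis = constant_axis_value;  // toco reads this from the constant buffer
  op->inputs.resize(2);            // keep only params and indices
  return true;
}
```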
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
index 5f8a06ba92..7d456af2fb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
@@ -48,6 +48,8 @@ bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) {
return ResolveAttributes(model, static_cast<TensorFlowSumOperator*>(op));
case OperatorType::kReduceProd:
return ResolveAttributes(model, static_cast<TensorFlowProdOperator*>(op));
+ case OperatorType::kReduceMin:
+ return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
case OperatorType::kReduceMax:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
default:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
index 2c7046c8c7..69bad2fa89 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
@@ -64,7 +64,14 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
const string& tmp_array_name =
AvailableArrayName(*model, op->outputs[0] + "_unfused");
CHECK(!model->HasArray(tmp_array_name));
- model->GetOrCreateArray(tmp_array_name);
+
+ const auto& output_array = model->GetArray(op->outputs[0]);
+ auto& tmp_array = model->GetOrCreateArray(tmp_array_name);
+ if (output_array.quantization_params) {
+ tmp_array.GetOrCreateQuantizationParams() =
+ output_array.GetQuantizationParams();
+ }
+
ac_op->inputs = {tmp_array_name};
op->outputs = {tmp_array_name};
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
index cbea39bcc0..dd9e26e68b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
@@ -187,6 +187,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted/perm"));
gather_params_permute_op->outputs.push_back(
AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted"));
+ gather_params_permute_op->axis = {0};
op_it = model->operators.emplace(op_it, gather_params_permute_op) + 1;
model->GetOrCreateArray(gather_params_permute_op->outputs[0]);
const auto& partition_array = model->GetArray(gather_ops[0]->inputs[0]);
@@ -212,6 +213,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
mod_op->inputs[0]};
merged_gather_op->outputs = {stitch_op->outputs[0]};
merged_gather_op->input_rank = partition_array.shape().dimensions_count();
+ merged_gather_op->axis = {0};
model->operators.emplace(op_it, merged_gather_op);
AddMessageF(
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 576eb71534..8bb797fe0f 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1042,22 +1042,6 @@ tensorflow::Status ConvertSimpleOperator(
return ConvertSimpleOperator<Op>(node, tf_import_flags, model);
}
-tensorflow::Status ConvertMinOperator(
- const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Min");
- TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- auto* op = new TensorFlowMinOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
- if (HasAttr(node, "keep_dims")) {
- op->keep_dims = GetBoolAttr(node, "keep_dims");
- }
- return tensorflow::Status::OK();
-}
-
tensorflow::Status ConvertUnsupportedOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1197,8 +1181,17 @@ tensorflow::Status ConvertGatherOperator(
auto* op = new GatherOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
- // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we
- // should read it an pass it on to the TF Lite Interpreter.
+ if (node.input_size() >= 3) {
+ // GatherV2 form where we are provided an axis. It may be either a constant
+ // or runtime defined value, so we just wire up the array and let
+ // ResolveGatherAttributes take care of it later on.
+ const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
+ CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
+ op->inputs.push_back(node.input(2));
+ } else {
+ // Gather form that assumes axis=0.
+ op->axis = {0};
+ }
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
return tensorflow::Status::OK();
@@ -1585,6 +1578,24 @@ tensorflow::Status ConvertShapeOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertAnyOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Any");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
+ const auto idx_type =
+ HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
+ CHECK(idx_type == DT_INT32);
+ auto op = absl::make_unique<AnyOperator>();
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ op->keep_dims =
+ HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false;
+ model->operators.push_back(std::move(op));
+ return tensorflow::Status::OK();
+}
+
void StripCaretFromArrayNames(Model* model) {
for (auto& op : model->operators) {
for (auto& input : op->inputs) {
@@ -1820,6 +1831,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
+ {"Any", ConvertAnyOperator},
{"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
{"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
@@ -1862,15 +1874,16 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
{"Log", ConvertSimpleOperator<LogOperator, 1>},
- {"Log", ConvertSimpleOperator<LogOperator, 1>},
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
+ {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>},
+ {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>},
{"MatMul", ConvertMatMulOperator},
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator},
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>},
{"Mean", ConvertReduceOperator<MeanOperator>},
{"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>},
- {"Min", ConvertMinOperator},
+ {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
{"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>},
{"Mul", ConvertSimpleOperator<MulOperator, 2>},
{"Neg", ConvertSimpleOperator<NegOperator, 1>},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 8fff68cf47..6fe194516d 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
@@ -109,7 +110,7 @@ enum class OperatorType : uint8 {
kLessEqual,
kReduceMax, // Reduction Max
kMaximum, // Element-wise Maximum
- kMin, // Reduction Min
+ kReduceMin, // Reduction Min
kMinimum, // Element-wise Minimum
kMatMul,
kMerge,
@@ -142,6 +143,9 @@ enum class OperatorType : uint8 {
kNotEqual,
kPow,
kArgMin,
+ kAny,
+ kLogicalAnd,
+ kLogicalNot,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1415,16 +1419,15 @@ struct TensorFlowMaxOperator : Operator {
bool keep_dims = false;
};
-// Global min reduction: computes the min of all of entries in the input array.
-// Thus the output is "0-dimensional": it consists of a single scalar value.
+// Min reduction: computes the min of entries across the given axes.
//
// Inputs:
// inputs[0]: required: the input array
//
-// TensorFlow equivalent: Min --- except that we only support the special case
-// of global reduction across all dimensions.
+// TensorFlow equivalent: Min
struct TensorFlowMinOperator : Operator {
- TensorFlowMinOperator() : Operator(OperatorType::kMin) {}
+ TensorFlowMinOperator() : Operator(OperatorType::kReduceMin) {}
+ std::vector<int> axis;
bool keep_dims = false;
};
@@ -1525,11 +1528,15 @@ struct FloorOperator : Operator {
// Inputs:
// inputs[0]: required: the params array
// inputs[1]: required: the indices to gather
+// inputs[2]: optional: axis
//
// TensorFlow equivalent: Gather
struct GatherOperator : Operator {
GatherOperator() : Operator(OperatorType::kGather) {}
- int axis = 0;
+ // Axis is populated explicitly or implicitly from the axis input by
+ // ResolveGatherAttributes. An empty axis indicates that the axis has not yet
+  // been resolved.
+ absl::optional<int> axis;
int input_rank = 0;
};
@@ -1685,6 +1692,39 @@ struct PowOperator : Operator {
PowOperator() : Operator(OperatorType::kPow) {}
};
+// Any operator:
+//
+// Inputs:
+// Inputs[0]: required: A boolean input tensor.
+// Inputs[1]: required: reduction_indices.
+//
+// TensorFlow equivalent: tf.reduce_any.
+struct AnyOperator : Operator {
+ AnyOperator() : Operator(OperatorType::kAny) {}
+ bool keep_dims = false;
+};
+
+// LogicalAnd operator:
+//
+// Inputs:
+// Inputs[0]: required: A boolean tensor.
+// Inputs[1]: required: A boolean tensor.
+//
+// TensorFlow equivalent: tf.logical_and.
+struct LogicalAndOperator : Operator {
+ LogicalAndOperator() : Operator(OperatorType::kLogicalAnd) {}
+};
+
+// LogicalNot operator:
+//
+// Inputs:
+// Inputs[0]: required: A boolean tensor.
+//
+// TensorFlow equivalent: tf.logical_not.
+struct LogicalNotOperator : Operator {
+ LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 68d13586f1..1a1c4b8944 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -370,12 +370,13 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
- return ::tflite::CreateGatherOptions(*builder, op.axis);
+ int axis = op.axis ? op.axis.value() : 0;
+ return ::tflite::CreateGatherOptions(*builder, axis);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {
- op->axis = options.axis();
+ op->axis = {options.axis()};
}
int GetVersion(const Operator& op) const override { return 1; }
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index d8964ebc13..aa7f6996eb 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -117,6 +117,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveConstantShapeOrRank);
transformations->Add(new MakeInitialDequantizeOperator);
transformations->Add(new UnpartitionEmbeddingLookup);
+ transformations->Add(new ResolveGatherAttributes);
}
bool SupportsQuantization(FileFormat format) {
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 4305727c8c..52f8df45a2 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -351,10 +351,10 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(LessEqual)
HANDLE_OPERATORTYPENAME_CASE(MatMul)
HANDLE_OPERATORTYPENAME_CASE(ReduceMax) // Reduction Max
- HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum
+ HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum
HANDLE_OPERATORTYPENAME_CASE(Merge)
- HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min
- HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
+ HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
+ HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
HANDLE_OPERATORTYPENAME_CASE(Neg)
HANDLE_OPERATORTYPENAME_CASE(Pack)
HANDLE_OPERATORTYPENAME_CASE(Pad)
@@ -399,6 +399,9 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Equal)
HANDLE_OPERATORTYPENAME_CASE(NotEqual)
HANDLE_OPERATORTYPENAME_CASE(Pow)
+ HANDLE_OPERATORTYPENAME_CASE(Any)
+ HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
+ HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -940,8 +943,12 @@ void CheckEachArray(const Model& model) {
// shape.
CHECK(array->has_shape());
      // Constant buffer should have a valid shape.
- for (int d : array->shape().dims()) {
- CHECK_GE(d, 1);
+ bool is_scalar =
+ array->shape().dimensions_count() == 1 && array->shape().dims(0) == 0;
+ if (!is_scalar) {
+ for (int d : array->shape().dims()) {
+ CHECK_GE(d, 1);
+ }
}
// The shape flat-size should agree with the buffer length.
CHECK_EQ(array->buffer->Length(),
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 6e7423f85e..ecf2e120df 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -229,6 +229,8 @@ tensorflow/core/kernels/cast_op_impl_int32.cc
tensorflow/core/kernels/cast_op_impl_int64.cc
tensorflow/core/kernels/cast_op_impl_int8.cc
tensorflow/core/kernels/cast_op_impl_uint16.cc
+tensorflow/core/kernels/cast_op_impl_uint32.cc
+tensorflow/core/kernels/cast_op_impl_uint64.cc
tensorflow/core/kernels/cast_op_impl_uint8.cc
tensorflow/core/kernels/boosted_trees/prediction_ops.cc
tensorflow/core/kernels/boosted_trees/resource_ops.cc
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
index b1cb89391c..99fecf9651 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -445,7 +445,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
se::Stream* comm_stream = nccl_stream->stream.get();
ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
- comm_stream->implementation()->CudaStreamMemberHack());
+ comm_stream->implementation()->GpuStreamMemberHack());
while (true) {
// Find collective to run.
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 136856c015..164f3e58e6 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -223,7 +223,6 @@ tf_kernel_library(
":model_ops_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
],
alwayslink = 1,
)
@@ -319,7 +318,6 @@ tf_kernel_library(
":stats_ops_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
],
alwayslink = 1,
)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 089b03dcb5..68c78e8301 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -831,9 +831,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
// The allocator is used to build the engine. The build and the built engine
// will be destroyed after we get the serialized engine string, so it's fine
// to use unique_ptr here.
- // TODO(aaroey): nvinfer1::IGpuAllocator doesn't have a virtual destructor
- // and destructing the unique_ptr will result in segfault, fix it.
- std::unique_ptr<TRTDeviceAllocator> alloc;
+ std::unique_ptr<TRTBaseAllocator> alloc;
auto device_alloc = GetDeviceAndAllocator(params, engine);
int cuda_device_id = 0;
if (device_alloc.first >= 0) {
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
index 988b35f74f..2de7973750 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
@@ -65,7 +65,7 @@ class IncPluginTRT : public OpKernel {
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
IncrementKernel(input_tensor.flat<float>().data(), inc_,
output_tensor->flat<float>().data(),
input_shape.num_elements(), *stream);
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 04d072f5d9..54009179a8 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -230,7 +230,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
@@ -391,7 +391,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
// TODO(jie): trt enqueue does not return error
auto& trt_execution_context_ptr = engine_ctx_pair.second;
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 6fe318be6a..9265250605 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -81,7 +81,7 @@ class TRTEngineOp : public AsyncOpKernel {
std::vector<string> output_nodes_;
// keep device allocator for TRT.
- std::unique_ptr<TRTDeviceAllocator> allocator_;
+ std::unique_ptr<TRTBaseAllocator> allocator_;
// serialized protobuf segment or trt engine depending on static_engine_ flag.
string serialized_segment_;
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
index 9f115990c3..81d7330b49 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
@@ -37,8 +37,22 @@ void TRTCudaAllocator::free(void* memory) { cudaFree(memory); }
void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment,
uint32_t flags) {
+  // WAR for allocator alignment requirement. Certain CUDA API calls require
+  // GPU memory aligned to cudaDeviceProp::textureAlignment.
+ // See issue #20856
+ alignment = 512;
assert((alignment & (alignment - 1)) == 0); // zero or a power of 2.
- void* mem = allocator_->AllocateRaw(alignment, size);
+ size_t total_size = size + alignment;
+ void* mem = allocator_->AllocateRaw(alignment, total_size);
+ if (!mem) {
+ return nullptr;
+ }
+
+ void* alloc_mem = mem;
+ CHECK(std::align(alignment, size, mem, total_size));
+ if (mem != alloc_mem) {
+ CHECK(mem_map_.insert({mem, alloc_mem}).second);
+ }
VLOG(2) << "Allocated " << size << " bytes with alignment " << alignment
<< " @ " << mem;
return mem;
@@ -51,7 +65,15 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator)
void TRTDeviceAllocator::free(void* memory) {
VLOG(2) << "Deallocating @ " << memory;
- allocator_->DeallocateRaw(memory);
+  // The allocation may have been adjusted for alignment; restore the original pointer.
+ if (memory) {
+ auto alloc_mem = mem_map_.find(memory);
+ if (alloc_mem != mem_map_.end()) {
+ memory = alloc_mem->second;
+ mem_map_.erase(alloc_mem->first);
+ }
+ allocator_->DeallocateRaw(memory);
+ }
}
} // namespace tensorrt
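
The allocate/free pair above follows the common over-allocate-then-align pattern: request extra bytes, align the pointer with std::align, and remember the original pointer so it can be handed back on free. A self-contained sketch using malloc/free in place of the TF allocator (names are ours, not the TRT classes; not thread-safe):

```c++
#include <cstdlib>
#include <memory>
#include <unordered_map>

// Maps an aligned pointer back to the pointer originally returned by malloc.
static std::unordered_map<void*, void*> g_original_ptr;

void* AlignedAllocate(std::size_t size, std::size_t alignment) {
  std::size_t total = size + alignment;   // slack so the alignment always fits
  void* original = std::malloc(total);    // stand-in for AllocateRaw()
  if (original == nullptr) return nullptr;
  void* aligned = original;
  if (std::align(alignment, size, aligned, total) == nullptr) {
    std::free(original);
    return nullptr;
  }
  if (aligned != original) g_original_ptr[aligned] = original;
  return aligned;
}

void AlignedFree(void* aligned) {
  if (aligned == nullptr) return;
  void* original = aligned;
  auto it = g_original_ptr.find(aligned);
  if (it != g_original_ptr.end()) {
    original = it->second;
    g_original_ptr.erase(it);
  }
  std::free(original);                    // stand-in for DeallocateRaw()
}
```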
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
index 97ac82ca5d..b8825b108d 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
@@ -37,7 +37,14 @@ class IGpuAllocator {
namespace tensorflow {
namespace tensorrt {
-class TRTCudaAllocator : public nvinfer1::IGpuAllocator {
+class TRTBaseAllocator : public nvinfer1::IGpuAllocator {
+  // Base allocator class so we can have a virtual destructor.
+ public:
+  // The Python wrapper seems unhappy with a pure virtual destructor.
+ virtual ~TRTBaseAllocator() = default;
+};
+
+class TRTCudaAllocator : public TRTBaseAllocator {
// Allocator implementation that is using cuda allocator instead of device
// allocator in case we can't get device allocator from TF.
public:
@@ -47,7 +54,7 @@ class TRTCudaAllocator : public nvinfer1::IGpuAllocator {
void free(void* memory) override;
};
-class TRTDeviceAllocator : public nvinfer1::IGpuAllocator {
+class TRTDeviceAllocator : public TRTBaseAllocator {
// Allocator implementation wrapping TF device allocators.
public:
TRTDeviceAllocator(tensorflow::Allocator* allocator);
@@ -62,6 +69,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator {
private:
tensorflow::Allocator* allocator_;
+
+  // Supporting the alignment adjustment requires a map so free() can recover
+  // the original pointer.
+ std::unordered_map<void*, void*> mem_map_;
};
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h
index b7d5ffd674..d7d56cb95e 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resources.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h
@@ -64,7 +64,7 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
std::unique_ptr<TRTInt8Calibrator> calibrator_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
- std::unique_ptr<nvinfer1::IGpuAllocator> allocator_;
+ std::unique_ptr<TRTBaseAllocator> allocator_;
tensorflow::tensorrt::Logger logger_;
// TODO(sami): Use threadpool threads!
std::unique_ptr<std::thread> thr_;
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
index 7c3ef498c9..035b112254 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
@@ -186,8 +186,8 @@ class TfTrtIntegrationTest(test_util.TensorFlowTestCase):
# Defaults to 2 runs to verify result across multiple runs is same.
for _ in range(num_runs):
new_val = sess.run(out, {inp: input_data})
- self.assertEquals(TEST_GRAPHS[graph_key].expected_output_dims,
- new_val.shape)
+ self.assertEqual(TEST_GRAPHS[graph_key].expected_output_dims,
+ new_val.shape)
if val is not None:
self.assertAllEqual(new_val, val)
val = new_val
@@ -220,19 +220,19 @@ class TfTrtIntegrationTest(test_util.TensorFlowTestCase):
for n in gdef.node:
if n.op == "TRTEngineOp":
num_engines += 1
- self.assertNotEqual("", n.attr["serialized_segment"].s)
- self.assertNotEqual("", n.attr["segment_funcdef_name"].s)
- self.assertEquals(n.attr["precision_mode"].s, precision_mode)
- self.assertEquals(n.attr["static_engine"].b, not dynamic_engine)
+ self.assertNotEqual(to_bytes(""), n.attr["serialized_segment"].s)
+ self.assertNotEqual(to_bytes(""), n.attr["segment_funcdef_name"].s)
+ self.assertEqual(n.attr["precision_mode"].s, to_bytes(precision_mode))
+ self.assertEqual(n.attr["static_engine"].b, not dynamic_engine)
if precision_mode == MODE_INT8 and is_calibrated:
- self.assertNotEqual("", n.attr["calibration_data"].s)
+ self.assertNotEqual(to_bytes(""), n.attr["calibration_data"].s)
else:
- self.assertEquals("", n.attr["calibration_data"].s)
+ self.assertEqual(to_bytes(""), n.attr["calibration_data"].s)
if precision_mode is None:
- self.assertEquals(num_engines, 0)
+ self.assertEqual(num_engines, 0)
else:
- self.assertEquals(num_engines,
- TEST_GRAPHS[graph_key].num_expected_engines)
+ self.assertEqual(num_engines,
+ TEST_GRAPHS[graph_key].num_expected_engines)
def _RunTest(self, graph_key, use_optimizer, precision_mode,
dynamic_infer_engine, dynamic_calib_engine):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 718ea630a8..78b79b111e 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -701,8 +701,6 @@ def generate_per_core_enqueue_ops_fn_for_host(
infeed_queue = tpu_feed.InfeedQueue(
number_of_tuple_elements=len(per_host_sharded_inputs[0]))
captured_infeed_queue.capture(infeed_queue)
- infeed_queue.set_configuration_from_sharded_input_tensors(
- per_host_sharded_inputs)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
@@ -837,8 +835,6 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
infeed_queue = tpu_feed.InfeedQueue(
number_of_tuple_elements=len(per_host_sharded_inputs[0]))
captured_infeed_queue.capture(infeed_queue)
- infeed_queue.set_configuration_from_sharded_input_tensors(
- per_host_sharded_inputs)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
@@ -867,7 +863,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
def tpu_ordinal_function_impl(replica_id):
if ctx.device_assignment:
- return ctx.device_assignment.tpu_ordinal(replica_id=replica_id)
+ return ctx.device_assignment.tpu_ordinal(replica=replica_id)
else:
return replica_id % num_replicas_per_host
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
index b07ee9fda9..17b79ee30c 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
@@ -51,7 +51,7 @@ For example, say we want to update 4 scattered elements to a rank-1 tensor to
8 elements. In Python, that update would look like this:
```python
- ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1] ,[7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 3cb51b0dbc..7110ffd40c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -856,7 +857,7 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
static_cast<ConcretePerOpGpuDevice*>(device);
DCHECK(concrete_device);
const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
- streams_[stream_id]->compute->implementation()->CudaStreamMemberHack());
+ streams_[stream_id]->compute->implementation()->GpuStreamMemberHack());
concrete_device->Reinitialize(context, cuda_stream, tf_gpu_id_, allocator,
scratch_[stream_id]);
}
diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc
index 4d83b25ce6..447338e7bd 100644
--- a/tensorflow/core/common_runtime/process_state.cc
+++ b/tensorflow/core/common_runtime/process_state.cc
@@ -71,7 +71,7 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
return MemDesc();
}
-Allocator* ProcessState::GetCPUAllocator(int numa_node) {
+VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
CHECK_GE(numa_node, 0);
if (!numa_enabled_) numa_node = 0;
mutex_lock lock(mu_);
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
index 0f4ae230bb..2892677333 100644
--- a/tensorflow/core/common_runtime/process_state.h
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -65,7 +65,7 @@ class ProcessState {
// Returns the one CPUAllocator used for the given numa_node.
// TEMPORARY: ignores numa_node.
- Allocator* GetCPUAllocator(int numa_node);
+ VisitableAllocator* GetCPUAllocator(int numa_node);
typedef std::unordered_map<const void*, MemDesc> MDMap;
@@ -87,7 +87,7 @@ class ProcessState {
mutex mu_;
- std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_);
virtual ~ProcessState();
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc
index 4652fbe406..9b4200e0b4 100644
--- a/tensorflow/core/graph/algorithm.cc
+++ b/tensorflow/core/graph/algorithm.cc
@@ -25,7 +25,8 @@ namespace tensorflow {
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave,
- const NodeComparator& stable_comparator) {
+ const NodeComparator& stable_comparator,
+ const EdgeFilter& edge_filter) {
// Stack of work to do.
struct Work {
Node* node;
@@ -52,7 +53,6 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter,
// Arrange to call leave(n) when all done with descendants.
if (leave) stack.push_back(Work{n, true});
- gtl::iterator_range<NeighborIter> nodes = n->out_nodes();
auto add_work = [&visited, &stack](Node* out) {
if (!visited[out->id()]) {
// Note; we must not mark as visited until we actually process it.
@@ -62,16 +62,20 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter,
if (stable_comparator) {
std::vector<Node*> nodes_sorted;
- for (Node* out : nodes) {
- nodes_sorted.emplace_back(out);
+ for (const Edge* out_edge : n->out_edges()) {
+ if (!edge_filter || edge_filter(*out_edge)) {
+ nodes_sorted.emplace_back(out_edge->dst());
+ }
}
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
for (Node* out : nodes_sorted) {
add_work(out);
}
} else {
- for (Node* out : nodes) {
- add_work(out);
+ for (const Edge* out_edge : n->out_edges()) {
+ if (!edge_filter || edge_filter(*out_edge)) {
+ add_work(out_edge->dst());
+ }
}
}
}
@@ -118,8 +122,6 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
// Arrange to call leave(n) when all done with descendants.
if (leave) stack.push_back(Work{n, true});
- gtl::iterator_range<NeighborIter> nodes = n->in_nodes();
-
auto add_work = [&visited, &stack](T out) {
if (!visited[out->id()]) {
// Note; we must not mark as visited until we actually process it.
@@ -129,16 +131,16 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
if (stable_comparator) {
std::vector<T> nodes_sorted;
- for (T in : nodes) {
- nodes_sorted.emplace_back(in);
+ for (const Edge* in_edge : n->in_edges()) {
+ nodes_sorted.emplace_back(in_edge->src());
}
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
for (T in : nodes_sorted) {
add_work(in);
}
} else {
- for (T in : nodes) {
- add_work(in);
+ for (const Edge* in_edge : n->in_edges()) {
+ add_work(in_edge->src());
}
}
}
@@ -161,14 +163,17 @@ void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
}
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
- const NodeComparator& stable_comparator) {
+ const NodeComparator& stable_comparator,
+ const EdgeFilter& edge_filter) {
order->clear();
- DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator);
+ DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator,
+ edge_filter);
}
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
- const NodeComparator& stable_comparator) {
- GetPostOrder(g, order, stable_comparator);
+ const NodeComparator& stable_comparator,
+ const EdgeFilter& edge_filter) {
+ GetPostOrder(g, order, stable_comparator, edge_filter);
std::reverse(order->begin(), order->end());
}
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
index ac4a099013..5bbbc6f6dc 100644
--- a/tensorflow/core/graph/algorithm.h
+++ b/tensorflow/core/graph/algorithm.h
@@ -28,6 +28,8 @@ namespace tensorflow {
// Comparator for two nodes. This is used in order to get a stable ordering.
using NodeComparator = std::function<bool(const Node*, const Node*)>;
+using EdgeFilter = std::function<bool(const Edge&)>;
+
// Compares two nodes based on their ids.
struct NodeComparatorID {
bool operator()(const Node* n1, const Node* n2) const {
@@ -47,9 +49,11 @@ struct NodeComparatorName {
// If leave is not empty, calls leave(n) after visiting all children of n.
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
+// If edge_filter is set, ignores edges for which edge_filter returns false.
extern void DFS(const Graph& g, const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave,
- const NodeComparator& stable_comparator = {});
+ const NodeComparator& stable_comparator = {},
+ const EdgeFilter& edge_filter = {});
// Perform a reverse depth-first-search on g starting at the sink node.
// If enter is not empty, calls enter(n) before visiting any parents of n.
@@ -83,15 +87,21 @@ extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
//
+// If edge_filter is set, ignores edges for which edge_filter returns false.
+//
// REQUIRES: order is not NULL.
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
- const NodeComparator& stable_comparator = {});
+ const NodeComparator& stable_comparator = {},
+ const EdgeFilter& edge_filter = {});
// Stores in *order the reverse post-order numbering of all nodes
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
+//
+// If edge_filter is set, ignores edges for which edge_filter returns false.
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
- const NodeComparator& stable_comparator = {});
+ const NodeComparator& stable_comparator = {},
+ const EdgeFilter& edge_filter = {});
// Prune nodes in "g" that are not in some path from the source node
// to any node in 'nodes'. Returns true if changes were made to the graph.
diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc
index f67d5a2fd2..60a3e66aa1 100644
--- a/tensorflow/core/graph/algorithm_test.cc
+++ b/tensorflow/core/graph/algorithm_test.cc
@@ -36,6 +36,11 @@ namespace {
REGISTER_OP("TestParams").Output("o: float");
REGISTER_OP("TestInput").Output("a: float").Output("b: float");
REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_OP("TestUnary").Input("a: float").Output("o: float");
+REGISTER_OP("TestBinary")
+ .Input("a: float")
+ .Input("b: float")
+ .Output("o: float");
// Compares that the order of nodes in 'inputs' respects the
// pair orders described in 'ordered_pairs'.
@@ -148,5 +153,52 @@ TEST(AlgorithmTest, ReversePostOrderStable) {
EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error));
}
}
+
+TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ string error;
+ Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0"));
+ Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1"));
+ Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2"));
+ Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3"));
+
+ Graph g(OpRegistry::Global());
+ TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
+
+ g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1);
+
+ std::vector<Node*> post_order;
+ auto edge_filter = [&](const Edge& e) {
+ return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id());
+ };
+
+ std::vector<Node*> expected_post_order = {
+ g.sink_node(), g.FindNodeId(n3->id()), g.FindNodeId(n2->id()),
+ g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()};
+
+ std::vector<Node*> expected_reverse_post_order = expected_post_order;
+ std::reverse(expected_reverse_post_order.begin(),
+ expected_reverse_post_order.end());
+
+ GetPostOrder(g, &post_order, /*stable_comparator=*/{},
+ /*edge_filter=*/edge_filter);
+
+ ASSERT_EQ(expected_post_order.size(), post_order.size());
+ for (int i = 0; i < post_order.size(); i++) {
+ CHECK_EQ(post_order[i], expected_post_order[i])
+ << post_order[i]->name() << " vs. " << expected_post_order[i]->name();
+ }
+
+ std::vector<Node*> reverse_post_order;
+ GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{},
+ /*edge_filter=*/edge_filter);
+
+ ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size());
+ for (int i = 0; i < reverse_post_order.size(); i++) {
+ CHECK_EQ(reverse_post_order[i], expected_reverse_post_order[i])
+ << reverse_post_order[i]->name() << " vs. "
+ << expected_reverse_post_order[i]->name();
+ }
+}
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index add26f3b71..8c73f8f712 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1042,6 +1042,14 @@ Status GraphConstructor::Convert() {
}
if (processed < node_defs_.size()) {
+    LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
+                 << " NODES IN A CYCLE";
+ for (int64 i = 0; i < node_defs_.size(); i++) {
+ if (pending_count_[i] != 0) {
+        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
+                     << " WITH PENDING COUNT = " << pending_count_[i];
+ }
+ }
return errors::InvalidArgument(node_defs_.size() - processed,
" nodes in a cycle");
}
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 8d8c6084ec..6d84283e68 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -29,6 +29,14 @@ void Cluster::AllowSoftPlacement(bool soft_placement_state) {
options_.config.set_allow_soft_placement(soft_placement_state);
}
+void Cluster::SetNumInterOpThreads(int num_threads) {
+ for (int i = 0; i < options_.config.session_inter_op_thread_pool_size();
+ ++i) {
+ options_.config.mutable_session_inter_op_thread_pool(i)->set_num_threads(
+ num_threads);
+ }
+}
+
void Cluster::SetNumWarmupSteps(int num_steps) {
options_.config.mutable_graph_options()->set_build_cost_model_after(
num_steps);
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h
index 06db36b3aa..e94fb900c0 100644
--- a/tensorflow/core/grappler/clusters/cluster.h
+++ b/tensorflow/core/grappler/clusters/cluster.h
@@ -65,6 +65,9 @@ class Cluster {
// with reftype input(s) which are from CPU.
void AllowSoftPlacement(bool soft_placement_state);
+ // Update the number of inter-op threads for each per-session threadpool
+ void SetNumInterOpThreads(int num_threads);
+
// Set the number of steps required to warmup TensorFlow. Must be called
// before Provision().
void SetNumWarmupSteps(int num_steps);
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 3cb9d4d61c..c8946c499c 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -48,10 +48,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core/grappler/clusters:virtual_cluster",
- "//tensorflow/core/grappler/optimizers:meta_optimizer",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index b5b46ccafe..ea5f450009 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -16,11 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/device_base.h"
-#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/graph_view.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/grappler_item_builder.h"
-#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1c842150fd..99e5e3cfca 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4855,6 +4855,8 @@ filegroup(
"cast_op_impl_int64.cc",
"cast_op_impl_int8.cc",
"cast_op_impl_uint16.cc",
+ "cast_op_impl_uint32.cc",
+ "cast_op_impl_uint64.cc",
"cast_op_impl_uint8.cc",
"concat_lib.h",
"concat_lib_cpu.cc",
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 626db9131a..e6e388b3d1 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -41,8 +41,10 @@ typedef Eigen::SyclDevice SYCLDevice;
#define CURRY_TYPES2(FN, arg0) \
FN(arg0, bool); \
FN(arg0, uint8); \
- FN(arg0, int8); \
FN(arg0, uint16); \
+ FN(arg0, uint32); \
+ FN(arg0, uint64); \
+ FN(arg0, int8); \
FN(arg0, int16); \
FN(arg0, int32); \
FN(arg0, int64); \
@@ -86,10 +88,14 @@ Status CpuCastOp::Prepare() {
work_ = GetCpuCastFromBool(dst_dtype_);
} else if (src_dtype_ == DT_UINT8) {
work_ = GetCpuCastFromUint8(dst_dtype_);
- } else if (src_dtype_ == DT_INT8) {
- work_ = GetCpuCastFromInt8(dst_dtype_);
} else if (src_dtype_ == DT_UINT16) {
work_ = GetCpuCastFromUint16(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT32) {
+ work_ = GetCpuCastFromUint32(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT64) {
+ work_ = GetCpuCastFromUint64(dst_dtype_);
+ } else if (src_dtype_ == DT_INT8) {
+ work_ = GetCpuCastFromInt8(dst_dtype_);
} else if (src_dtype_ == DT_INT16) {
work_ = GetCpuCastFromInt16(dst_dtype_);
} else if (src_dtype_ == DT_INT32) {
@@ -135,10 +141,14 @@ class GpuCastOp : public CastOpBase {
work_ = GetGpuCastFromBool(dst_dtype_);
} else if (src_dtype_ == DT_UINT8) {
work_ = GetGpuCastFromUint8(dst_dtype_);
- } else if (src_dtype_ == DT_INT8) {
- work_ = GetGpuCastFromInt8(dst_dtype_);
} else if (src_dtype_ == DT_UINT16) {
work_ = GetGpuCastFromUint16(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT32) {
+ work_ = GetGpuCastFromUint32(dst_dtype_);
+ } else if (src_dtype_ == DT_UINT64) {
+ work_ = GetGpuCastFromUint64(dst_dtype_);
+ } else if (src_dtype_ == DT_INT8) {
+ work_ = GetGpuCastFromInt8(dst_dtype_);
} else if (src_dtype_ == DT_INT16) {
work_ = GetGpuCastFromInt16(dst_dtype_);
} else if (src_dtype_ == DT_INT32) {
@@ -178,8 +188,10 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
-CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
+CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
+CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index 9c9e9e7658..607e7f5efd 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -37,8 +37,10 @@ struct CastFunctor<GPUDevice, O, I> {
#define DEFINE_ALL_FROM(in_type) \
DEFINE(in_type, bool); \
DEFINE(in_type, uint8); \
- DEFINE(in_type, int8); \
DEFINE(in_type, uint16); \
+ DEFINE(in_type, uint32); \
+ DEFINE(in_type, uint64); \
+ DEFINE(in_type, int8); \
DEFINE(in_type, int16); \
DEFINE(in_type, int32); \
DEFINE(in_type, int64); \
@@ -50,8 +52,10 @@ struct CastFunctor<GPUDevice, O, I> {
DEFINE_ALL_FROM(bool);
DEFINE_ALL_FROM(uint8);
-DEFINE_ALL_FROM(int8);
DEFINE_ALL_FROM(uint16);
+DEFINE_ALL_FROM(uint32);
+DEFINE_ALL_FROM(uint64);
+DEFINE_ALL_FROM(int8);
DEFINE_ALL_FROM(int16);
DEFINE_ALL_FROM(int32);
DEFINE_ALL_FROM(int64);
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
index 382e5440e1..fe821b25df 100644
--- a/tensorflow/core/kernels/cast_op_impl.h
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -48,8 +48,10 @@ struct CastFunctor<Eigen::SyclDevice, O, I> {
#define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \
FN(arg0, arg1, bool); \
FN(arg0, arg1, uint8); \
- FN(arg0, arg1, int8); \
FN(arg0, arg1, uint16); \
+ FN(arg0, arg1, uint32); \
+ FN(arg0, arg1, uint64); \
+ FN(arg0, arg1, int8); \
FN(arg0, arg1, int16); \
FN(arg0, arg1, int32); \
FN(arg0, arg1, int64); \
@@ -82,10 +84,16 @@ std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
GetCpuCastFromUint8(DataType dst_dtype);
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt8(DataType dst_dtype);
+GetCpuCastFromUint16(DataType dst_dtype);
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint16(DataType dst_dtype);
+GetCpuCastFromUint32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromInt8(DataType dst_dtype);
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
GetCpuCastFromInt16(DataType dst_dtype);
@@ -123,10 +131,16 @@ std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
GetGpuCastFromUint8(DataType dst_dtype);
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt8(DataType dst_dtype);
+GetGpuCastFromUint16(DataType dst_dtype);
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint16(DataType dst_dtype);
+GetGpuCastFromUint32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromInt8(DataType dst_dtype);
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
GetGpuCastFromInt16(DataType dst_dtype);
@@ -168,6 +182,12 @@ std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
GetSyclCastFromUint16(DataType dst_dtype);
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromUint32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromUint64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
GetSyclCastFromInt16(DataType dst_dtype);
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
diff --git a/tensorflow/core/kernels/cast_op_impl_uint32.cc b/tensorflow/core/kernels/cast_op_impl_uint32.cc
new file mode 100644
index 0000000000..d1a854d98b
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_uint32.cc
@@ -0,0 +1,46 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint32(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, uint32);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint32(DataType dst_dtype) {
+ CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint32);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromUint32(DataType dst_dtype) {
+ CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32);
+ return nullptr;
+}
+#endif // TENSORFLOW_USE_SYCL
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_uint64.cc b/tensorflow/core/kernels/cast_op_impl_uint64.cc
new file mode 100644
index 0000000000..604e0424fc
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_impl_uint64.cc
@@ -0,0 +1,46 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cast_op_impl.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetCpuCastFromUint64(DataType dst_dtype) {
+ CURRY_TYPES3(CAST_CASE, CPUDevice, uint64);
+ return nullptr;
+}
+
+#if GOOGLE_CUDA
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetGpuCastFromUint64(DataType dst_dtype) {
+ CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint64);
+ return nullptr;
+}
+#endif // GOOGLE_CUDA
+
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromUint64(DataType dst_dtype) {
+ CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint64);
+ return nullptr;
+}
+#endif // TENSORFLOW_USE_SYCL
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index 7da9d28a3d..75e21802c0 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -70,6 +70,8 @@ class CastOpTest : public OpsTestBase {
#define TEST_ALL_CASTS_FROM(in) \
TEST_CAST(in, uint8); \
TEST_CAST(in, uint16); \
+ TEST_CAST(in, uint32); \
+ TEST_CAST(in, uint64); \
TEST_CAST(in, int16); \
TEST_CAST(in, int32); \
TEST_CAST(in, int64); \
@@ -80,6 +82,8 @@ class CastOpTest : public OpsTestBase {
TEST_ALL_CASTS_FROM(uint8)
TEST_ALL_CASTS_FROM(uint16)
+TEST_ALL_CASTS_FROM(uint32)
+TEST_ALL_CASTS_FROM(uint64)
TEST_ALL_CASTS_FROM(int16)
TEST_ALL_CASTS_FROM(int32)
TEST_ALL_CASTS_FROM(int64)
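The added `TEST_ALL_CASTS_FROM(uint32)` / `TEST_ALL_CASTS_FROM(uint64)` cases exercise the new kernels from C++; from Python the same kernels back `tf.cast`. A minimal sketch, assuming a build that includes the registrations added in `cast_op_impl_uint32.cc` / `cast_op_impl_uint64.cc`:

```python
import tensorflow as tf

# int32 -> uint32 uses the existing "from int32" kernel; uint32 -> int64 goes
# through the newly added GetCpuCastFromUint32 dispatch.
x = tf.constant([1, 2, 3], dtype=tf.int32)
as_u32 = tf.cast(x, tf.uint32)
back = tf.cast(as_u32, tf.int64)

with tf.Session() as sess:
    print(sess.run(back))  # [1 2 3]
```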
diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index a857bd3ce4..a59baaa96f 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -151,7 +151,7 @@ CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
cuda_stream_ = *cu_stream_ptr;
HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
auto it = handle_map->find(cuda_stream_);
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 5472a192d9..2a25459194 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -641,20 +641,6 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
return Status::OK();
}
-namespace {
-// Returns whether the context's GPU supports efficient fp16 math.
-bool HasFastHalfMath(OpKernelContext* ctx) {
- int major, minor;
- ctx->op_device_context()
- ->stream()
- ->parent()
- ->GetDeviceDescription()
- .cuda_compute_capability(&major, &minor);
- auto cuda_arch = major * 100 + minor * 10;
- // GPUs before sm_53 don't support fp16 math, and sm_61's fp16 math is slow.
- return cuda_arch >= 530 && cuda_arch != 610;
-}
-
namespace detail {
template <typename T>
struct PseudoHalfType {
@@ -666,9 +652,23 @@ struct PseudoHalfType<Eigen::half> {
};
} // namespace detail
+namespace {
// Maps to float if T is __half, and to T otherwise.
template <typename T>
using PseudoHalfType = typename detail::PseudoHalfType<T>::Type;
+
+// Returns whether the context's GPU supports efficient fp16 math.
+bool HasFastHalfMath(OpKernelContext* ctx) {
+ int major, minor;
+ ctx->op_device_context()
+ ->stream()
+ ->parent()
+ ->GetDeviceDescription()
+ .cuda_compute_capability(&major, &minor);
+ auto cuda_arch = major * 100 + minor * 10;
+ // GPUs before sm_53 don't support fp16 math, and sm_61's fp16 math is slow.
+ return cuda_arch >= 530 && cuda_arch != 610;
+}
} // namespace
template <typename T, DepthwiseConv2dDirection kDirection,
diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h
index 81df7a51d7..d0d95736d3 100644
--- a/tensorflow/core/util/cuda_launch_config.h
+++ b/tensorflow/core/util/cuda_launch_config.h
@@ -295,7 +295,7 @@ inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
- ->CudaStreamMemberHack()));
+ ->GpuStreamMemberHack()));
return *ptr;
}
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index e98206eef9..42ad9652f8 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -225,7 +225,7 @@ the tape backwards and then discard. A particular `tf.GradientTape` can only
compute one gradient; subsequent calls throw a runtime error.
```py
-w = tfe.Variable([[1.0]])
+w = tf.Variable([[1.0]])
with tf.GradientTape() as tape:
loss = w * w
@@ -260,8 +260,8 @@ def grad(weights, biases):
train_steps = 200
learning_rate = 0.01
# Start with arbitrary values for W and B on the same batch of data
-W = tfe.Variable(5.)
-B = tfe.Variable(10.)
+W = tf.Variable(5.)
+B = tf.Variable(10.)
print("Initial loss: {:.3f}".format(loss(W, B)))
@@ -407,11 +407,11 @@ with tf.device("/gpu:0"):
### Variables and optimizers
-`tfe.Variable` objects store mutable `tf.Tensor` values accessed during
+`tf.Variable` objects store mutable `tf.Tensor` values accessed during
training to make automatic differentiation easier. The parameters of a model can
be encapsulated in classes as variables.
-Better encapsulate model parameters by using `tfe.Variable` with
+Better encapsulate model parameters by using `tf.Variable` with
`tf.GradientTape`. For example, the automatic differentiation example above
can be rewritten:
@@ -419,8 +419,8 @@ can be rewritten:
class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
- self.W = tfe.Variable(5., name='weight')
- self.B = tfe.Variable(10., name='bias')
+ self.W = tf.Variable(5., name='weight')
+ self.B = tf.Variable(10., name='bias')
def call(self, inputs):
return inputs * self.W + self.B
@@ -498,17 +498,17 @@ is removed, and is then deleted.
```py
with tf.device("gpu:0"):
- v = tfe.Variable(tf.random_normal([1000, 1000]))
+ v = tf.Variable(tf.random_normal([1000, 1000]))
v = None # v no longer takes up GPU memory
```
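Putting the updated snippets together, a minimal sketch of the `tf.Variable` plus `tf.GradientTape` pattern the guide now recommends (arbitrary example values; TF 1.x-style eager enablement assumed):

```python
import tensorflow as tf
tf.enable_eager_execution()

W = tf.Variable(5.0)
B = tf.Variable(10.0)

def loss(inputs, targets):
  return tf.reduce_mean(tf.square(inputs * W + B - targets))

with tf.GradientTape() as tape:
  current_loss = loss(tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0]))

# Gradients are taken with respect to the tf.Variable objects directly.
dW, dB = tape.gradient(current_loss, [W, B])
```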
### Object-based saving
-`tfe.Checkpoint` can save and restore `tfe.Variable`s to and from
+`tfe.Checkpoint` can save and restore `tf.Variable`s to and from
checkpoints:
```py
-x = tfe.Variable(10.)
+x = tf.Variable(10.)
checkpoint = tfe.Checkpoint(x=x) # save as "x"
@@ -612,7 +612,7 @@ def line_search_step(fn, init_x, rate=1.0):
`tf.GradientTape` is a powerful interface for computing gradients, but there
is another [Autograd](https://github.com/HIPS/autograd)-style API available for
automatic differentiation. These functions are useful if writing math code with
-only tensors and gradient functions, and without `tfe.Variables`:
+only tensors and gradient functions, and without `tf.Variables`:
* `tfe.gradients_function` —Returns a function that computes the derivatives
of its input function parameter with respect to its arguments. The input
diff --git a/tensorflow/docs_src/mobile/index.md b/tensorflow/docs_src/mobile/index.md
index 419ae7094a..6032fcad02 100644
--- a/tensorflow/docs_src/mobile/index.md
+++ b/tensorflow/docs_src/mobile/index.md
@@ -13,9 +13,6 @@ Here are a few of the differences between the two:
developed with TensorFlow Lite will have a smaller binary size, fewer
dependencies, and better performance.
-- TensorFlow Lite is in developer preview, so not all use cases are covered yet.
- We expect you to use TensorFlow Mobile to cover production cases.
-
- TensorFlow Lite supports only a limited set of operators, so not all models
will work on it by default. TensorFlow for Mobile has a fuller set of
supported functionality.
diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/docs_src/mobile/tflite/index.md
index 3d1733024e..cc4af2a875 100644
--- a/tensorflow/docs_src/mobile/tflite/index.md
+++ b/tensorflow/docs_src/mobile/tflite/index.md
@@ -70,10 +70,9 @@ There are several factors which are fueling interest in this domain:
We believe the next wave of machine learning applications will have significant
processing on mobile and embedded devices.
-## TensorFlow Lite developer preview highlights
+## TensorFlow Lite highlights
-TensorFlow Lite is available as a developer preview and includes the
-following:
+TensorFlow Lite provides:
- A set of core operators, both quantized and float, many of which have been
tuned for mobile platforms. These can be used to create and run custom
@@ -129,9 +128,6 @@ following:
- Java and C++ API support
-Note: This is a developer release, and it’s likely that there will be changes in
-the API in upcoming versions. We do not guarantee backward or forward
-compatibility with this release.
## Getting Started
@@ -201,9 +197,5 @@ possible performance for a particular model on a particular device.
## Next Steps
-For the developer preview, most of our documentation is on GitHub. Please take a
-look at the [TensorFlow Lite
-repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite)
-on GitHub for more information and for code samples, demo applications, and
-more.
-
+The TensorFlow Lite [GitHub repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite)
+contains additional docs, code samples, and demo applications.
diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py
index 7657b23c60..89e790d4e4 100644
--- a/tensorflow/examples/speech_commands/freeze.py
+++ b/tensorflow/examples/speech_commands/freeze.py
@@ -130,7 +130,7 @@ def main(_):
FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
if FLAGS.quantize:
- tf.contrib.quantize.create_training_graph(quant_delay=0)
+ tf.contrib.quantize.create_eval_graph()
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
# Turn all the variables into inline constants inside the graph and save it.
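The switch from `create_training_graph(quant_delay=0)` to `create_eval_graph()` means the frozen graph receives the inference-time quantization rewrite rather than the training rewrite. A minimal sketch of that ordering, with `build_inference_graph()` as a hypothetical stand-in for the model construction done in `freeze.py`:

```python
import tensorflow as tf

def freeze_quantized(checkpoint_path, output_node_name):
  graph = tf.Graph()
  with graph.as_default(), tf.Session() as sess:
    build_inference_graph()  # hypothetical: builds the model to be frozen
    # Rewrite with fake-quant ops suitable for inference, then restore weights
    # and fold variables into constants.
    tf.contrib.quantize.create_eval_graph()
    tf.train.Saver().restore(sess, checkpoint_path)
    return tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), [output_node_name])
```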
diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py
index 65ae3b1511..4d1454be0d 100644
--- a/tensorflow/examples/speech_commands/models.py
+++ b/tensorflow/examples/speech_commands/models.py
@@ -302,7 +302,7 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
label_count = model_settings['label_count']
final_fc_weights = tf.get_variable(
name='final_fc_weights',
- initializer=tf.truncated_normal_initializer,
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[second_conv_element_count, label_count])
final_fc_bias = tf.get_variable(
name='final_fc_bias',
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 573422e533..fbc2a11eda 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -217,10 +217,9 @@ class Model(Network):
for name in self.output_names:
if name not in loss:
logging.warning(
- 'Output "' + name + '" missing from loss dictionary. '
- 'We assume this was done on purpose, '
- 'and we will not be expecting '
- 'any data to be passed to "' + name + '" during training.')
+ 'Output "' + name + '" missing from loss dictionary. We assume '
+ 'this was done on purpose. The fit and evaluate APIs will not be '
+ 'expecting any data to be passed to "' + name + '".')
loss_functions.append(losses.get(loss.get(name)))
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index d9e548f01f..c621a88fb3 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import logging
import os
import unittest
@@ -415,6 +416,28 @@ class TrainingTest(test.TestCase):
x2 = model.predict(val_a)
self.assertAllClose(x1, x2, atol=1e-7)
+ def test_compile_warning_for_loss_missing_output(self):
+ with self.test_session():
+ inp = keras.layers.Input(shape=(16,), name='input_a')
+ out_1 = keras.layers.Dense(8, name='dense_1')(inp)
+ out_2 = keras.layers.Dense(3, activation='softmax', name='dense_2')(out_1)
+ model = keras.models.Model(inputs=[inp], outputs=[out_1, out_2])
+
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ model.compile(
+ loss={
+ 'dense_2': 'categorical_crossentropy',
+ },
+ optimizer='rmsprop',
+ metrics={
+ 'dense_2': 'categorical_accuracy',
+ 'dense_1': 'categorical_accuracy',
+ })
+ msg = ('Output "dense_1" missing from loss dictionary. We assume this '
+ 'was done on purpose. The fit and evaluate APIs will not be '
+ 'expecting any data to be passed to "dense_1".')
+ self.assertRegexpMatches(str(mock_log.call_args), msg)
+
class LossWeightingTest(test.TestCase):
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index e03d7dfe93..72e15763cb 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -19,9 +19,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from abc import ABCMeta
+from abc import abstractmethod
import six
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.losses import binary_crossentropy
from tensorflow.python.keras.losses import categorical_crossentropy
from tensorflow.python.keras.losses import cosine_proximity
@@ -37,11 +44,385 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.losses import squared_hinge
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import confusion_matrix
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
+def update_state(update_state_fn):
+ """Decorator to wrap metric `update_state()` with `defun()`, `add_update()`.
+
+ Args:
+ update_state_fn: function that accumulates metric statistics.
+
+ Returns:
+ If eager execution is enabled, returns None.
+ If graph execution is enabled, returns an update op. This op should be
+ executed to update the metric state with the given inputs.
+ """
+
+ def decorated(*args, **kwargs):
+ """Decorated function with `defun()` and `add_update()`."""
+
+ # Converting update_state_fn() into a graph function, so that
+ # we can return a single op that performs all of the variable updates.
+ # Assigning to a different method name to avoid reference cycle.
+ defuned_update_state_fn = function.defun(update_state_fn)
+ update_op = defuned_update_state_fn(*args, **kwargs)
+ if update_op is not None: # update_op will be None in eager execution.
+ metric_obj = args[0]
+ metric_obj.add_update(update_op, inputs=True)
+ return update_op
+
+ return tf_decorator.make_decorator(update_state_fn, decorated)
+
+
+def result(result_fn):
+ """Decorator to wrap metric `result()` function in `merge_call()`.
+
+ Result computation is an idempotent operation that simply calculates the
+ metric value using the state variables.
+
+ If metric state variables are distributed across towers/devices and
+  `result()` is requested from the context of one device, this function wraps
+ `result()` in a distribution strategy `merge_call()`. With this,
+ the metric state variables will be aggregated across devices.
+
+ Args:
+ result_fn: function that computes the metric result.
+
+ Returns:
+ The metric result tensor.
+ """
+
+ def decorated(*args):
+ """Decorated function with merge_call."""
+ tower_context = distribute_lib.get_tower_context()
+ if tower_context is None: # if in cross tower context already
+ return result_fn()
+
+ # TODO(psv): Test distribution of metrics using different distribution
+ # strategies.
+
+ # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
+ # with distribution object as the first parameter. We create a wrapper here
+ # so that the result function need not have that parameter.
+ def merge_fn_wrapper(distribution, merge_fn, *args):
+ # We will get `PerDevice` merge function. Taking the first one as all are
+ # identical copies of the function that we had passed below.
+ return distribution.unwrap(merge_fn)[0](*args)
+
+ # Wrapping result in merge_call. merge_call is used when we want to leave
+ # tower mode and compute a value in cross tower mode.
+ return tower_context.merge_call(merge_fn_wrapper, result_fn, *args)
+
+ return tf_decorator.make_decorator(result_fn, decorated)
+
+
+def _safe_div(numerator, denominator):
+ """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
+
+ Args:
+ numerator: A `Tensor`.
+ denominator: A `Tensor`, with dtype matching `numerator`.
+
+ Returns:
+ 0 if `denominator` <= 0, else `numerator` / `denominator`
+ """
+ t = math_ops.truediv(numerator, denominator)
+ zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+ condition = math_ops.greater(denominator, zero)
+ zero = math_ops.cast(zero, t.dtype)
+ return array_ops.where(condition, t, zero)
+
+
+def _squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
+ """Squeeze or expand last dimension if needed.
+
+ 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
+ (using `confusion_matrix.remove_squeezable_dimensions`).
+ 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
+ from the new rank of `y_pred`.
+ If `sample_weight` is scalar, it is kept scalar.
+
+ This will use static shape if available. Otherwise, it will add graph
+ operations, which could result in a performance hit.
+
+ Args:
+ y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
+ y_true: Optional label `Tensor` whose dimensions match `y_pred`.
+ sample_weight: Optional weight scalar or `Tensor` whose dimensions match
+ `y_pred`.
+
+ Returns:
+    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
+    the last dimension squeezed;
+    `sample_weight` could be extended by one dimension.
+ """
+ if y_true is not None:
+ # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
+ y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
+ y_true, y_pred)
+ y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
+
+ if sample_weight is None:
+ return y_pred, y_true, None
+
+ sample_weight = ops.convert_to_tensor(sample_weight)
+ weights_shape = sample_weight.get_shape()
+ weights_rank = weights_shape.ndims
+ if weights_rank == 0: # If weights is scalar, do nothing.
+ return y_pred, y_true, sample_weight
+
+ y_pred_shape = y_pred.get_shape()
+ y_pred_rank = y_pred_shape.ndims
+ if (y_pred_rank is not None) and (weights_rank is not None):
+ # Use static rank.
+ if weights_rank - y_pred_rank == 1:
+ sample_weight = array_ops.squeeze(sample_weight, [-1])
+ elif y_pred_rank - weights_rank == 1:
+ sample_weight = array_ops.expand_dims(sample_weight, [-1])
+ return y_pred, y_true, sample_weight
+
+ # Use dynamic rank.
+ weights_rank_tensor = array_ops.rank(sample_weight)
+ rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
+ maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
+
+ def _maybe_expand_weights():
+ return control_flow_ops.cond(
+ math_ops.equal(rank_diff,
+ -1), lambda: array_ops.expand_dims(sample_weight, [-1]),
+ lambda: sample_weight)
+
+ def _maybe_adjust_weights():
+ return control_flow_ops.cond(
+ math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
+ _maybe_expand_weights)
+
+ # squeeze or expand last dim of `sample_weight` if its rank differs by 1
+ # from the new rank of `y_pred`.
+ sample_weight = control_flow_ops.cond(
+ math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
+ _maybe_adjust_weights)
+ return y_pred, y_true, sample_weight
+
+
+class Metric(Layer):
+ """Encapsulates metric logic and state.
+
+ Usage with eager execution:
+
+ ```python
+ m = SomeMetric(...)
+ for input in ...:
+ m.update_state(input)
+ print('Final result: ', m.result().numpy())
+ ```
+
+ Usage with graph execution:
+
+ ```python
+ m = SomeMetric(...)
+ init_op = tf.global_variables_initializer() # Initialize variables
+ with tf.Session() as sess:
+ sess.run(init_op)
+ for input in ...:
+ update_op = m.update_state(input)
+ sess.run(update_op)
+ print('Final result: ', sess.run(m.result()))
+ ```
+
+ To be implemented by subclasses:
+ * `__init__()`: All state variables should be created in this method by
+ calling `self.add_weight()` like: `self.var = self.add_weight(...)`
+ * `update_state()`: Has all updates to the state variables like:
+ self.var.assign_add(...). Please decorate the function with:
+ @update_state: Converts `update_state()` into a graph function, so that
+ we can return a single op that performs all of the variable updates and
+ adds the update op to the metric layer.
+ * `result()`: Computes and returns a value for the metric
+ from the state variables. Please decorate the function with:
+ @result: Wraps `result()` in a distribution strategy merge_call().
+
+ Example subclass implementation:
+
+ ```
+ class BinaryTruePositives(Metric):
+ def __init__(self, name='binary-true-positives', dtype=dtypes.float64):
+ super(BinaryTruePositives, self).__init__(name=name, dtype=dtype)
+ self.true_positives = self.add_weight(
+ 'true_positives', initializer=init_ops.zeros_initializer)
+
+ @update_state
+ def update_state(self, y_true, y_pred, sample_weight=None):
+ y_true = math_ops.cast(y_true, dtypes.bool)
+ y_pred = math_ops.cast(y_pred, dtypes.bool)
+ y_pred, y_true, sample_weight = _squeeze_or_expand_dimensions(
+ y_pred, y_true, sample_weight)
+
+ values = math_ops.logical_and(
+ math_ops.equal(y_true, True), math_ops.equal(y_pred, True))
+ values = math_ops.cast(values, self._dtype)
+ if sample_weight is not None:
+ sample_weight = math_ops.cast(sample_weight, self._dtype)
+ values = math_ops.multiply(values, sample_weight)
+ state_ops.assign_add(self.true_positives, math_ops.reduce_sum(values))
+
+ @result
+ def result(self):
+ return array_ops.identity(self.true_positives)
+ ```
+ """
+ __metaclass__ = ABCMeta
+
+ def __init__(self, name=None, dtype=dtypes.float64):
+ super(Metric, self).__init__(name=name, dtype=dtype)
+ self.stateful = True # All metric layers are stateful.
+ self.built = True
+
+ def __call__(self, *args, **kwargs):
+ """Accumulates statistics and then computes metric result value.
+
+ Args:
+ *args:
+ **kwargs: A mini-batch of inputs to the Metric,
+ passed on to `update_state()`.
+
+ Returns:
+ The metric value tensor.
+ """
+ update_op = self.update_state(*args, **kwargs)
+ with ops.control_dependencies([update_op]):
+ return self.result()
+
+ def reset_states(self):
+ """Resets all of the metric state variables.
+
+ This function is called between epochs/steps,
+ when a metric is evaluated during training.
+ """
+ for v in self.variables:
+ K.set_value(v, 0)
+
+ @abstractmethod
+ def update_state(self, *args, **kwargs):
+ """Accumulates statistics for the metric.
+
+ Please decorate the function with:
+ @update_state: Converts `update_state()` into a graph function, so that
+      we can return a single op that performs all of the variable updates.
+ This means:
+ a) Operations on the same resource are executed in textual order.
+ This should make it easier to do things like add the updated
+ value of a variable to another, for example.
+ b) You don't need to worry about collecting the update ops to execute.
+ All update ops added to the graph by this function will be executed.
+ As a result, code should generally work the same way with graph or
+ eager execution.
+ and adds the update op to the metric layer.
+
+ Args:
+ *args:
+ **kwargs: A mini-batch of inputs to the Metric.
+ """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+ @abstractmethod
+ def result(self):
+ """Computes and returns the metric value tensor.
+
+ Result computation is an idempotent operation that simply calculates the
+ metric value using the state variables.
+
+ Please decorate the function with:
+ @result: Wraps `result()` in a distribution strategy merge_call().
+ """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+ ### For use by subclasses ###
+ def add_weight(self,
+ name,
+ shape=(),
+ aggregation=vs.VariableAggregation.SUM,
+ synchronization=vs.VariableSynchronization.ON_READ,
+ initializer=None):
+ """Adds state variable. Only for use by subclasses."""
+ return super(Metric, self).add_weight(
+ name=name,
+ shape=shape,
+ dtype=self._dtype,
+ trainable=False,
+ initializer=initializer,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ ### End: For use by subclasses ###
+
+
+class Mean(Metric):
+ """Computes the (weighted) mean of the given values.
+
+ This metric creates two variables, `total` and `count` that are used to
+ compute the average of `values`. This average is ultimately returned as `mean`
+ which is an idempotent operation that simply divides `total` by `count`.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+ """
+
+ def __init__(self, name='mean', dtype=dtypes.float64):
+ super(Mean, self).__init__(name=name, dtype=dtype)
+ # Create new state variables
+ self.total = self.add_weight(
+ 'total', initializer=init_ops.zeros_initializer)
+ self.count = self.add_weight(
+ 'count', initializer=init_ops.zeros_initializer)
+
+ @update_state
+ def update_state(self, values, sample_weight=None):
+ """Accumulates statistics for computing the mean.
+
+ For example, if `values` is [1, 3, 5, 7] then the mean is 4. If
+ the `sample_weight` is specified as [1, 1, 0, 0] then the mean would be 2.
+
+ Args:
+ values: Per-example value.
+ sample_weight: Optional weighting of each example. Defaults to 1.
+ """
+ values = math_ops.cast(values, self._dtype)
+ if sample_weight is None:
+ num_values = math_ops.cast(array_ops.size(values), self._dtype)
+ else:
+ sample_weight = math_ops.cast(sample_weight, self._dtype)
+
+ # Update dimensions of weights to match with values.
+ values, _, sample_weight = _squeeze_or_expand_dimensions(
+ values, None, sample_weight)
+ sample_weight = weights_broadcast_ops.broadcast_weights(
+ sample_weight, values)
+ num_values = math_ops.reduce_sum(sample_weight)
+ values = math_ops.multiply(values, sample_weight)
+ values = math_ops.reduce_sum(values)
+
+ # Update state variables
+ state_ops.assign_add(self.total, values)
+ state_ops.assign_add(self.count, num_values)
+
+ @result
+ def result(self):
+ return _safe_div(self.total, self.count)
+
+
@tf_export('keras.metrics.binary_accuracy')
def binary_accuracy(y_true, y_pred):
return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1)
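As a quick sketch of the weighted-mean arithmetic described in the `Mean` docstring above (values `[1, 3, 5, 7]` with weights `[1, 1, 0, 0]` give a mean of 2), using the class as added in this change:

```python
import tensorflow as tf
from tensorflow.python.keras import metrics

m = metrics.Mean()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # __call__ accumulates state and returns the result tensor (graph mode).
  result = m([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
  # total = 1 + 3 = 4, count = 1 + 1 = 2, so the mean is 2.0.
  print(sess.run(result))
```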
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 15e793f5fc..6d8269f34d 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -18,67 +18,72 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import numpy as np
-from tensorflow.python import keras
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import metrics
+from tensorflow.python.keras.engine.training import Model
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
class KerasMetricsTest(test.TestCase):
def test_metrics(self):
with self.test_session():
- y_a = keras.backend.variable(np.random.random((6, 7)))
- y_b = keras.backend.variable(np.random.random((6, 7)))
- for metric in [keras.metrics.binary_accuracy,
- keras.metrics.categorical_accuracy]:
+ y_a = K.variable(np.random.random((6, 7)))
+ y_b = K.variable(np.random.random((6, 7)))
+ for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]:
output = metric(y_a, y_b)
- self.assertEqual(keras.backend.eval(output).shape, (6,))
+ self.assertEqual(K.eval(output).shape, (6,))
def test_sparse_categorical_accuracy(self):
with self.test_session():
- metric = keras.metrics.sparse_categorical_accuracy
- y_a = keras.backend.variable(np.random.randint(0, 7, (6,)))
- y_b = keras.backend.variable(np.random.random((6, 7)))
- self.assertEqual(keras.backend.eval(metric(y_a, y_b)).shape, (6,))
+ metric = metrics.sparse_categorical_accuracy
+ y_a = K.variable(np.random.randint(0, 7, (6,)))
+ y_b = K.variable(np.random.random((6, 7)))
+ self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,))
def test_sparse_top_k_categorical_accuracy(self):
with self.test_session():
- y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1],
- [0.1, 0.2, 0.7]]))
- y_true = keras.backend.variable(np.array([[1], [0]]))
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([[1], [0]]))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
self.assertEqual(result, 1)
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
self.assertEqual(result, 0.5)
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
def test_top_k_categorical_accuracy(self):
with self.test_session():
- y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1],
- [0.1, 0.2, 0.7]]))
- y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]]))
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3))
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]]))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3))
self.assertEqual(result, 1)
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=2))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=2))
self.assertEqual(result, 0.5)
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
def test_stateful_metrics(self):
with self.test_session():
np.random.seed(1334)
- class BinaryTruePositives(keras.layers.Layer):
+ class BinaryTruePositives(layers.Layer):
"""Stateful Metric to count the total true positives over all batches.
Assumes predictions and targets of shape `(samples, 1)`.
@@ -91,11 +96,11 @@ class KerasMetricsTest(test.TestCase):
def __init__(self, name='true_positives', **kwargs):
super(BinaryTruePositives, self).__init__(name=name, **kwargs)
- self.true_positives = keras.backend.variable(value=0, dtype='int32')
+ self.true_positives = K.variable(value=0, dtype='int32')
self.stateful = True
def reset_states(self):
- keras.backend.set_value(self.true_positives, 0)
+ K.set_value(self.true_positives, 0)
def __call__(self, y_true, y_pred):
"""Computes the number of true positives in a batch.
@@ -120,14 +125,14 @@ class KerasMetricsTest(test.TestCase):
return current_true_pos + true_pos
metric_fn = BinaryTruePositives()
- config = keras.metrics.serialize(metric_fn)
- metric_fn = keras.metrics.deserialize(
+ config = metrics.serialize(metric_fn)
+ metric_fn = metrics.deserialize(
config, custom_objects={'BinaryTruePositives': BinaryTruePositives})
# Test on simple model
- inputs = keras.Input(shape=(2,))
- outputs = keras.layers.Dense(1, activation='sigmoid')(inputs)
- model = keras.Model(inputs, outputs)
+ inputs = layers.Input(shape=(2,))
+ outputs = layers.Dense(1, activation='sigmoid')(inputs)
+ model = Model(inputs, outputs)
model.compile(optimizer='sgd',
loss='binary_crossentropy',
metrics=['acc', metric_fn])
@@ -184,6 +189,125 @@ class KerasMetricsTest(test.TestCase):
self.assertAllClose(
val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
+ @test_util.run_in_graph_and_eager_modes
+ def test_mean(self):
+ m = metrics.Mean(name='my_mean')
+
+ # check config
+ self.assertEqual(m.name, 'my_mean')
+ self.assertTrue(m.stateful)
+ self.assertEqual(m.dtype, dtypes.float64)
+ self.assertEqual(len(m.variables), 2)
+ self.evaluate(variables.global_variables_initializer())
+
+ # check initial state
+ self.assertEqual(self.evaluate(m.total), 0)
+ self.assertEqual(self.evaluate(m.count), 0)
+
+ # check __call__()
+ self.assertEqual(self.evaluate(m(100)), 100)
+ self.assertEqual(self.evaluate(m.total), 100)
+ self.assertEqual(self.evaluate(m.count), 1)
+
+ # check update_state() and result() + state accumulation + tensor input
+ update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
+ self.evaluate(update_op)
+ self.assertEqual(self.evaluate(m.result()), 106 / 3)
+ self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5
+ self.assertEqual(self.evaluate(m.count), 3)
+
+ # check reset_states()
+ m.reset_states()
+ self.assertEqual(self.evaluate(m.total), 0)
+ self.assertEqual(self.evaluate(m.count), 0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_mean_with_sample_weight(self):
+ m = metrics.Mean()
+ self.evaluate(variables.global_variables_initializer())
+
+ # check scalar weight
+ result_t = m(100, sample_weight=0.5)
+ self.assertEqual(self.evaluate(result_t), 50 / 0.5)
+ self.assertEqual(self.evaluate(m.total), 50)
+ self.assertEqual(self.evaluate(m.count), 0.5)
+
+ # check weights not scalar and weights rank matches values rank
+ result_t = m([1, 5], sample_weight=[1, 0.2])
+ result = self.evaluate(result_t)
+ self.assertAlmostEqual(result, 52 / 1.7, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2
+ self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2
+
+ # check weights broadcast
+ result_t = m([1, 2], sample_weight=0.5)
+ self.assertAlmostEqual(self.evaluate(result_t), 53.5 / 2.7, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 53.5, 2) # 52 + 0.5 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 2.7, 2) # 1.7 + 0.5 + 0.5
+
+ # check weights squeeze
+ result_t = m([1, 5], sample_weight=[[1], [0.2]])
+ self.assertAlmostEqual(self.evaluate(result_t), 55.5 / 3.9, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 55.5, 2) # 53.5 + 1 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 3.9, 2) # 2.7 + 1.2
+
+ # check weights expand
+ result_t = m([[1], [5]], sample_weight=[1, 0.2])
+ self.assertAlmostEqual(self.evaluate(result_t), 57.5 / 5.1, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 57.5, 2) # 55.5 + 1 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 5.1, 2) # 3.9 + 1.2
+
+ def test_mean_graph_with_placeholder(self):
+ with context.graph_mode(), self.test_session() as sess:
+ m = metrics.Mean()
+ v = array_ops.placeholder(dtypes.float32)
+ w = array_ops.placeholder(dtypes.float32)
+ sess.run(variables.global_variables_initializer())
+
+ # check __call__()
+ result_t = m(v, sample_weight=w)
+ result = sess.run(result_t, feed_dict=({v: 100, w: 0.5}))
+ self.assertEqual(sess.run(m.total), 50)
+ self.assertEqual(sess.run(m.count), 0.5)
+ self.assertEqual(result, 50 / 0.5)
+
+ # check update_state() and result()
+ result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]}))
+ self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2
+ self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2
+ self.assertAlmostEqual(result, 52 / 1.7, 2)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_save_restore(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ m = metrics.Mean()
+ checkpoint = checkpointable_utils.Checkpoint(mean=m)
+ self.evaluate(variables.global_variables_initializer())
+
+ # update state
+ self.evaluate(m(100.))
+ self.evaluate(m(200.))
+
+ # save checkpoint and then add an update
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.evaluate(m(1000.))
+
+ # restore to the same checkpoint mean object
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.evaluate(m(300.))
+ self.assertEqual(200., self.evaluate(m.result()))
+
+ # restore to a different checkpoint mean object
+ restore_mean = metrics.Mean()
+ restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
+ status = restore_checkpoint.restore(save_path)
+ restore_update = restore_mean(300.)
+ status.assert_consumed().run_restore_ops()
+ self.evaluate(restore_update)
+ self.assertEqual(200., self.evaluate(restore_mean.result()))
+ self.assertEqual(3, self.evaluate(restore_mean.count))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index e358293a90..c739cd2c0d 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -246,6 +246,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[2]])
+ def testUseResource(self):
+ v = variables.Variable(1.0, use_resource=True)
+ self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
+
+ def testEagerNoUseResource(self):
+ with context.eager_mode():
+ v = variables.Variable(1.0)
+ self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
+
@test_util.run_in_graph_and_eager_modes
def testScatterMin(self):
with ops.device("cpu:0"):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 1f56ad25bf..5979b76ff2 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -1294,3 +1294,16 @@ def is_resource_variable(var):
""""Returns True if `var` is to be considered a ResourceVariable."""
return isinstance(var, ResourceVariable) or hasattr(
var, "_should_act_as_resource_variable")
+
+
+_DEFAULT_USE_RESOURCE = False
+
+
+def _default_variable_creator(_, *args, **kwds):
+ use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE)
+ use_resource = use_resource or context.executing_eagerly()
+ if use_resource:
+ return ResourceVariable(*args, **kwds)
+ return variables.RefVariable(*args, **kwds)
+
+variables.default_variable_creator = _default_variable_creator
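With the new default creator, `tf.Variable` routes to a `ResourceVariable` when `use_resource=True` (and always under eager execution), and to the legacy ref-based variable otherwise. A minimal sketch mirroring the tests above:

```python
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

v_ref = tf.Variable(1.0)                     # graph-mode default: ref variable
v_res = tf.Variable(1.0, use_resource=True)  # routed to ResourceVariable

print(resource_variable_ops.is_resource_variable(v_ref))  # False
print(resource_variable_ops.is_resource_variable(v_res))  # True
```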
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 77f67c18ee..0f37dcc027 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -191,36 +191,9 @@ class _ReuseMode(enum.Enum):
# REUSE_TRUE = 3
-@tf_export("VariableSynchronization")
-class VariableSynchronization(enum.Enum):
- """Indicates when a distributed variable will be synced."""
-
- # Indicates that the synchronization will be determined by the current
- # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
- # `ON_WRITE`).
- AUTO = 0
-
- # Indicates that there will only be one copy of the variable, so there is no
- # need to sync.
- NONE = 1
-
- # Indicates that the variable will be aggregated across devices
- # every time it is updated.
- ON_WRITE = 2
-
- # Indicates that the variable will be aggregated across devices
- # when it is read (eg. when checkpointing or when evaluating an op that uses
- # the variable).
- ON_READ = 3
-
-
-@tf_export("VariableAggregation")
-class VariableAggregation(enum.Enum):
- """Indicates how a distributed variable will be aggregated."""
- NONE = 0
- SUM = 1
- MEAN = 2
-
+# TODO(apassos) remove these forwarding symbols.
+VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name
+VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 87e0de197c..6bb2d6f669 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import enum # pylint: disable=g-bad-import-order
+
import six
from tensorflow.core.framework import attr_value_pb2
@@ -38,8 +40,9 @@ from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
-def _default_variable_creator(_, *args, **kwds):
- return RefVariable(*args, **kwds)
+def default_variable_creator(_, *args, **kwds):
+ del args, kwds
+ raise NotImplementedError("resource_variable_ops needs to be imported")
def _make_getter(captured_getter, captured_previous):
@@ -49,12 +52,43 @@ def _make_getter(captured_getter, captured_previous):
return getter
+@tf_export("VariableSynchronization")
+class VariableSynchronization(enum.Enum):
+ """Indicates when a distributed variable will be synced."""
+
+ # Indicates that the synchronization will be determined by the current
+ # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
+ # `ON_WRITE`).
+ AUTO = 0
+
+ # Indicates that there will only be one copy of the variable, so there is no
+ # need to sync.
+ NONE = 1
+
+ # Indicates that the variable will be aggregated across devices
+ # every time it is updated.
+ ON_WRITE = 2
+
+ # Indicates that the variable will be aggregated across devices
+ # when it is read (eg. when checkpointing or when evaluating an op that uses
+ # the variable).
+ ON_READ = 3
+
+
+@tf_export("VariableAggregation")
+class VariableAggregation(enum.Enum):
+ """Indicates how a distributed variable will be aggregated."""
+ NONE = 0
+ SUM = 1
+ MEAN = 2
+
+
class VariableMetaclass(type):
"""Metaclass to allow construction of tf.Variable to be overridden."""
def __call__(cls, *args, **kwargs):
if cls is Variable:
- previous_getter = lambda *a, **k: _default_variable_creator(None, *a, **k)
+ previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k)
# TODO(apassos) use a stack of getters here
return previous_getter(*args, **kwargs)
else:
@@ -172,14 +206,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
* Replace `tf.Variable` with `tf.contrib.eager.Variable`;
* Call `tf.get_variable_scope().set_use_resource(True)` inside a
`tf.variable_scope` before the `tf.get_variable()` call.
-
- @compatibility(eager)
- `tf.Variable` is not compatible with eager execution. Use
- `tf.contrib.eager.Variable` instead which is compatible with both eager
- execution and graph construction. See [the TensorFlow Eager Execution
- guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
- for details on how variables work in eager execution.
- @end_compatibility
"""
def __init__(self,
@@ -193,7 +219,10 @@ class Variable(six.with_metaclass(VariableMetaclass,
dtype=None,
expected_shape=None,
import_scope=None,
- constraint=None):
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Creates a new variable with value `initial_value`.
The new variable is added to the graph collections listed in `collections`,
@@ -245,20 +274,24 @@ class Variable(six.with_metaclass(VariableMetaclass,
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
+ use_resource: if True, a ResourceVariable is created; otherwise an
+        old-style ref-based variable is created. When eager execution is enabled,
+        a resource variable is always created.
+      synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Raises:
ValueError: If both `variable_def` and initial_value are specified.
ValueError: If the initial value is not specified, or does not have a
shape and `validate_shape` is `True`.
RuntimeError: If eager execution is enabled.
-
- @compatibility(eager)
- `tf.Variable` is not compatible with eager execution. Use
- `tfe.Variable` instead which is compatible with both eager execution
- and graph construction. See [the TensorFlow Eager Execution
- guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
- for details on how variables work in eager execution.
- @end_compatibility
"""
raise NotImplementedError
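The constructor now mirrors the `synchronization`/`aggregation` arguments already plumbed through variable creation elsewhere in the tree (the new `Metric.add_weight` above relies on the same machinery). A minimal sketch using the exported enums, assuming the `tf.get_variable` signature at this point accepts these arguments; note the documented constraint that `ON_READ` synchronization must not be combined with `trainable=True`:

```python
import tensorflow as tf

# The enums are now exported at the top level (forwarded from variables.py).
print(list(tf.VariableSynchronization))  # AUTO, NONE, ON_WRITE, ON_READ
print(list(tf.VariableAggregation))      # NONE, SUM, MEAN

v = tf.get_variable(
    'replica_sum',
    shape=(),
    trainable=False,  # required when synchronization is ON_READ
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.SUM)
```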
@@ -1714,7 +1747,7 @@ class PartitionedVariable(object):
"""A container for partitioned `Variable` objects.
@compatibility(eager) `tf.PartitionedVariable` is not compatible with
- eager execution. Use `tfe.Variable` instead which is compatible
+ eager execution. Use `tf.Variable` instead which is compatible
with both eager execution and graph construction. See [the
TensorFlow Eager Execution
guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
index fd697d70bf..45de047894 100644
--- a/tensorflow/python/platform/gfile.py
+++ b/tensorflow/python/platform/gfile.py
@@ -38,7 +38,14 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('gfile.GFile', 'gfile.Open')
class GFile(_FileIO):
- """File I/O wrappers without thread locking."""
+ """File I/O wrappers without thread locking.
+
+  Note that this is somewhat like built-in Python file I/O, but
+ there are semantic differences to make it more efficient for
+ some backing filesystems. For example, a write mode file will
+ not be opened until the first write call (to minimize RPC
+ invocations in network filesystems).
+ """
def __init__(self, name, mode='r'):
super(GFile, self).__init__(name=name, mode=mode)
@@ -46,7 +53,14 @@ class GFile(_FileIO):
@tf_export('gfile.FastGFile')
class FastGFile(_FileIO):
- """File I/O wrappers without thread locking."""
+ """File I/O wrappers without thread locking.
+
+  Note that this is somewhat like built-in Python file I/O, but
+ there are semantic differences to make it more efficient for
+ some backing filesystems. For example, a write mode file will
+ not be opened until the first write call (to minimize RPC
+ invocations in network filesystems).
+ """
def __init__(self, name, mode='r'):
super(FastGFile, self).__init__(name=name, mode=mode)
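The docstring addition calls out that write-mode files are opened lazily, on the first `write()` call, to reduce RPCs on network filesystems. A small usage sketch with a local path:

```python
import tensorflow as tf

# The underlying file handle is not created until the first write() call.
with tf.gfile.GFile('/tmp/example.txt', mode='w') as f:
  f.write('hello\n')

with tf.gfile.GFile('/tmp/example.txt', mode='r') as f:
  print(f.read())
```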
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index f11022ef1d..259c813c57 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -844,7 +844,7 @@ CUDAExecutor::GetTimerImplementation() {
return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this));
}
-void *CUDAExecutor::CudaContextHack() { return context_; }
+void *CUDAExecutor::GpuContextHack() { return context_; }
CudaContext* CUDAExecutor::cuda_context() { return context_; }
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index 773cbfb8a1..f7c341c857 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -210,7 +210,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override;
- void *CudaContextHack() override;
+ void *GpuContextHack() override;
CudaContext* cuda_context();
diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h
index 02edff6431..bb8bda4755 100644
--- a/tensorflow/stream_executor/cuda/cuda_stream.h
+++ b/tensorflow/stream_executor/cuda/cuda_stream.h
@@ -40,8 +40,8 @@ class CUDAStream : public internal::StreamInterface {
// Note: teardown is handled by a parent's call to DeallocateStream.
~CUDAStream() override {}
- void *CudaStreamHack() override { return cuda_stream_; }
- void **CudaStreamMemberHack() override {
+ void *GpuStreamHack() override { return cuda_stream_; }
+ void **GpuStreamMemberHack() override {
return reinterpret_cast<void **>(&cuda_stream_);
}
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index e82f57569f..858396ef96 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -202,7 +202,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
return std::unique_ptr<internal::TimerInterface>(new HostTimer());
}
- void *CudaContextHack() override { return nullptr; }
+ void *GpuContextHack() override { return nullptr; }
private:
const PluginConfig plugin_config_;
diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h
index 5d7b8a3782..be88f074cf 100644
--- a/tensorflow/stream_executor/host/host_stream.h
+++ b/tensorflow/stream_executor/host/host_stream.h
@@ -34,8 +34,8 @@ class HostStream : public internal::StreamInterface {
bool EnqueueTask(std::function<void()> task);
- void *CudaStreamHack() override { return nullptr; }
- void **CudaStreamMemberHack() override { return nullptr; }
+ void *GpuStreamHack() override { return nullptr; }
+ void **GpuStreamMemberHack() override { return nullptr; }
void BlockUntilDone();
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 9c989b971d..fb1b92cb84 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -100,19 +100,20 @@ class StreamInterface {
// Default destructor for the abstract interface.
virtual ~StreamInterface() {}
- // Returns the CUDA stream associated with this platform's stream
+ // Returns the GPU stream associated with this platform's stream
// implementation.
//
- // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
- // fatal error if it is not. This hack is made available solely for use from
- // distbelief code, which temporarily has strong ties to CUDA as a platform.
- virtual void *CudaStreamHack() { return nullptr; }
-
- // See the above comment on CudaStreamHack -- this further breaks abstraction
- // for Eigen within distbelief, which has strong ties to CUDA as a platform,
- // and a historical attachment to a programming model which takes a
+ // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
+ // causing a fatal error if it is not. This hack is made available solely for
+ // use from distbelief code, which temporarily has strong ties to CUDA or
+ // ROCm as a platform.
+ virtual void *GpuStreamHack() { return nullptr; }
+
+ // See the above comment on GpuStreamHack -- this further breaks abstraction
+ // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a
+ // platform, and a historical attachment to a programming model which takes a
// stream-slot rather than a stream-value.
- virtual void **CudaStreamMemberHack() { return nullptr; }
+ virtual void **GpuStreamMemberHack() { return nullptr; }
private:
SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
@@ -324,13 +325,14 @@ class StreamExecutorInterface {
virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;
- // Returns the CUDA context associated with this StreamExecutor platform
- // implementation.
+ // Returns the CUDA or ROCm context associated with this StreamExecutor
+ // platform implementation.
//
- // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
- // fatal error if it is not. This hack is made available solely for use from
- // distbelief code, which temporarily has strong ties to CUDA as a platform.
- virtual void *CudaContextHack() { return nullptr; }
+ // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
+ // causing a fatal error if it is not. This hack is made available solely for
+ // use from distbelief code, which temporarily has strong ties to CUDA or ROCm
+ // as a platform.
+ virtual void *GpuContextHack() { return nullptr; }
private:
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
index 23b552cc38..e841c4ad89 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
@@ -49,7 +49,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "assign"
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index db37edf809..866fe95d2b 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -354,7 +354,7 @@ do_external_licenses_check(){
# Whitelist
echo ${EXTRA_LICENSE_FILE}
- grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -v ${EXTRA_LICENSES_FILE} > temp.txt
+ grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -e "@embedded_jdk//" -v ${EXTRA_LICENSES_FILE} > temp.txt
mv temp.txt ${EXTRA_LICENSES_FILE}
diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh
index adbff8f6ef..e284401b8a 100755
--- a/tensorflow/tools/ci_build/install/install_bazel.sh
+++ b/tensorflow/tools/ci_build/install/install_bazel.sh
@@ -15,7 +15,7 @@
# ==============================================================================
# Select bazel version.
-BAZEL_VERSION="0.14.1"
+BAZEL_VERSION="0.15.0"
set +e
local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
diff --git a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
index 9d24b3e421..87be81577d 100755
--- a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
+++ b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
@@ -18,7 +18,7 @@
# It will compile bazel from source and install it in /usr/local/bin
# Select bazel version.
-BAZEL_VERSION="0.14.1"
+BAZEL_VERSION="0.15.0"
set +e
local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index c03cbd9c66..0482cf619a 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -33,10 +33,10 @@ function set_remote_cache_options {
echo "build --tls_enabled=true" >> "${TMP_BAZELRC}"
echo "build --remote_timeout=3600" >> "${TMP_BAZELRC}"
echo "build --auth_enabled=true" >> "${TMP_BAZELRC}"
- echo "build --spawn_strategy=remote" >> "${TMP_BAZELRC}"
- echo "build --strategy=Javac=remote" >> "${TMP_BAZELRC}"
- echo "build --strategy=Closure=remote" >> "${TMP_BAZELRC}"
- echo "build --genrule_strategy=remote" >> "${TMP_BAZELRC}"
+ echo "build --spawn_strategy=standalone" >> "${TMP_BAZELRC}"
+ echo "build --strategy=Javac=standalone" >> "${TMP_BAZELRC}"
+ echo "build --strategy=Closure=standalone" >> "${TMP_BAZELRC}"
+ echo "build --genrule_strategy=standalone" >> "${TMP_BAZELRC}"
echo "build --google_credentials=$GOOGLE_CLOUD_CREDENTIAL" >> "${TMP_BAZELRC}"
}
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index fd94d64268..f7fe4119da 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -63,7 +63,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.14.1
+ENV BAZEL_VERSION 0.15.0
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 44120bf274..957a7ed799 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -83,7 +83,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.14.1
+ENV BAZEL_VERSION 0.15.0
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
index 3bedc8cf34..30bc2d2806 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
@@ -4,7 +4,7 @@ LABEL maintainer="Gunhan Gulsoy <gunan@google.com>"
# It is possible to override these for releases.
ARG TF_BRANCH=master
-ARG BAZEL_VERSION=0.5.4
+ARG BAZEL_VERSION=0.15.0
ARG TF_AVAILABLE_CPUS=32
RUN apt-get update && apt-get install -y --no-install-recommends \
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index af17fd75bc..cb084e49b7 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -247,9 +247,16 @@ Status SortByExecutionOrder(const GraphDef& input_graph_def,
}
}
- if (processed < input_graph_def.node_size()) {
- return errors::InvalidArgument(input_graph_def.node_size() - processed,
- " nodes in a cycle");
+ if (processed < num_nodes) {
+ LOG(WARNING) << "IN " << __func__ << (num_nodes - processed)
+ << " NODES IN A CYCLE";
+ for (int64 i = 0; i < num_nodes; i++) {
+ if (pending_count[i] != 0) {
+ LOG(WARNING) << "PENDING: " << SummarizeNodeDef(input_graph_def.node(i))
+ << "WITH PENDING COUNT = " << pending_count[i];
+ }
+ }
+ return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle");
}
return Status::OK();
}
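
The hunk above only adds diagnostics; the pending-count topological sort itself is unchanged. As a rough, framework-free illustration of why a nonzero pending count after the sweep implies a cycle, here is a small Python sketch of the same bookkeeping (names and structure are illustrative, not taken from transform_utils.cc):

    from collections import deque

    def sort_by_execution_order(inputs_of):
      """inputs_of maps node name -> list of its input node names.

      Every node, including graph sources, must appear as a key.
      """
      pending_count = {n: len(ins) for n, ins in inputs_of.items()}
      consumers = {n: [] for n in inputs_of}
      for node, ins in inputs_of.items():
        for src in ins:
          consumers[src].append(node)
      ready = deque(n for n, c in pending_count.items() if c == 0)
      order = []
      while ready:
        node = ready.popleft()
        order.append(node)
        for consumer in consumers[node]:
          pending_count[consumer] -= 1
          if pending_count[consumer] == 0:
            ready.append(consumer)
      if len(order) < len(inputs_of):
        # Unprocessed nodes still have inputs that never became ready, i.e.
        # they sit on or downstream of a cycle -- report them, as the warning
        # loop above now does.
        stuck = [n for n, c in pending_count.items() if c != 0]
        raise ValueError("%d nodes in a cycle: %s" % (len(stuck), stuck))
      return order

For example, {"a": [], "b": ["a"], "c": ["a", "b"]} sorts to ["a", "b", "c"], while {"x": ["y"], "y": ["x"]} raises with both nodes reported.
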
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 378de4261c..4b4f31813c 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -487,11 +487,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bd8c8d759852871609ba2e4e79868420f751949d.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/bd8c8d759852871609ba2e4e79868420f751949d.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/10c3b3d15ed6a788ac12221b784caf81fb8248b5.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/10c3b3d15ed6a788ac12221b784caf81fb8248b5.tar.gz",
],
- sha256 = "0c63e8583b213543309e8577ffe87a0cf34cc22269630d2c5c2f0a2345fda4a8",
- strip_prefix = "llvm-bd8c8d759852871609ba2e4e79868420f751949d",
+ sha256 = "a9feb6b47267c30fd7c19ebfdf4dbde6757054f716fa77c09bcb1106799c3253",
+ strip_prefix = "llvm-10c3b3d15ed6a788ac12221b784caf81fb8248b5",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index 67456a5bdf..c242ef3fdd 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -419,7 +419,7 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
# Create a custom learning rate Variable for the RMSProp optimizer, because
# the learning rate needs to be manually decayed later (see
# decay_learning_rate()).
- self._learning_rate = tfe.Variable(lr, name="learning_rate")
+ self._learning_rate = tf.Variable(lr, name="learning_rate")
self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
epsilon=1e-6)
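
The comment above describes the pattern this one-line change preserves: the learning rate lives in a Variable so it can be decayed in place later. Below is a minimal sketch of that pattern under eager execution; the initial rate, decay factor, and helper name are illustrative rather than taken from spinn.py.

    import tensorflow as tf

    tf.enable_eager_execution()

    # Keep the learning rate in a Variable so it can be decayed manually later.
    learning_rate = tf.Variable(2e-3, name="learning_rate", trainable=False)
    optimizer = tf.train.RMSPropOptimizer(learning_rate, epsilon=1e-6)

    def decay_learning_rate(decay_by=0.75):
      # Under eager execution the assignment takes effect immediately; the
      # optimizer reads the updated value on its next apply_gradients() call.
      learning_rate.assign(learning_rate * decay_by)
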