author     Yifei Feng <yifeif@google.com>  2018-02-01 17:58:54 -0800
committer  Yifei Feng <yifeif@google.com>  2018-02-01 17:58:54 -0800
commit     7ef914f41f1b376eacf41ba99a78491190c3a949 (patch)
tree       186d6b07e8827e682a278e97694a4d7100509b0e /tensorflow
parent     73019bc43d81c781b591407f97f409b8570c6115 (diff)
parent     ff81ca3d1303ec3ad178113a3398f8f1cac0304d (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/BUILD | 1
-rw-r--r--  tensorflow/compiler/aot/tfcompile.bzl | 83
-rw-r--r--  tensorflow/compiler/tests/BUILD | 13
-rw-r--r--  tensorflow/compiler/tests/binary_ops_test.py | 44
-rw-r--r--  tensorflow/compiler/tests/matrix_band_part_test.py | 64
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc | 98
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc | 93
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.cc | 56
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.h | 17
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.i | 197
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.cc | 138
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.h | 10
-rw-r--r--  tensorflow/compiler/xla/python/xla_client.py | 156
-rw-r--r--  tensorflow/compiler/xla/python/xla_client_test.py | 24
-rw-r--r--  tensorflow/contrib/bayesflow/BUILD | 1
-rw-r--r--  tensorflow/contrib/cmake/external/protobuf.cmake | 2
-rw-r--r--  tensorflow/contrib/compiler/jit_test.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py | 21
-rw-r--r--  tensorflow/contrib/eager/python/examples/mnist/mnist.py | 3
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py | 10
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py | 79
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator.py | 9
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc | 69
-rw-r--r--  tensorflow/contrib/lite/interpreter.h | 19
-rw-r--r--  tensorflow/contrib/lite/interpreter_test.cc | 127
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/conv.cc | 122
-rw-r--r--  tensorflow/contrib/lite/kernels/conv_test.cc | 70
-rw-r--r--  tensorflow/contrib/lite/kernels/test_util.cc | 11
-rw-r--r--  tensorflow/contrib/lite/kernels/test_util.h | 52
-rwxr-xr-x  tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh | 81
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc | 3
-rw-r--r--  tensorflow/contrib/lite/tools/BUILD | 4
-rw-r--r--  tensorflow/contrib/lite/tools/verifier.cc | 170
-rw-r--r--  tensorflow/contrib/lite/tools/verifier.h | 4
-rw-r--r--  tensorflow/contrib/lite/tools/verifier_test.cc | 200
-rw-r--r--  tensorflow/contrib/py2tf/BUILD | 1
-rw-r--r--  tensorflow/contrib/py2tf/__init__.py | 3
-rw-r--r--  tensorflow/contrib/py2tf/converters/BUILD | 1
-rw-r--r--  tensorflow/contrib/py2tf/converters/side_effect_guards.py | 19
-rw-r--r--  tensorflow/contrib/py2tf/converters/side_effect_guards_test.py | 2
-rw-r--r--  tensorflow/contrib/py2tf/impl/config.py | 3
-rw-r--r--  tensorflow/contrib/py2tf/utils/BUILD | 37
-rw-r--r--  tensorflow/contrib/py2tf/utils/__init__.py | 21
-rw-r--r--  tensorflow/contrib/py2tf/utils/context_managers.py | 41
-rw-r--r--  tensorflow/contrib/py2tf/utils/context_managers_test.py | 43
-rw-r--r--  tensorflow/contrib/quantize/BUILD | 4
-rw-r--r--  tensorflow/contrib/quantize/python/fold_batch_norms.py | 430
-rw-r--r--  tensorflow/contrib/quantize/python/fold_batch_norms_test.py | 97
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_graph.py | 12
-rw-r--r--  tensorflow/contrib/training/BUILD | 23
-rw-r--r--  tensorflow/contrib/training/python/training/tensor_queue_dataset.py | 200
-rw-r--r--  tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py | 355
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_EnqueueInQueueDataset.pbtxt | 3
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_PrependFromQueueAndPaddedBatchDataset.pbtxt | 3
-rw-r--r--  tensorflow/core/common_runtime/graph_execution_state.cc | 7
-rw-r--r--  tensorflow/core/distributed_runtime/tensor_coding.cc | 4
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.cc | 7
-rw-r--r--  tensorflow/core/grappler/graph_view.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/memory_optimizer.cc | 67
-rw-r--r--  tensorflow/core/grappler/utils/BUILD | 26
-rw-r--r--  tensorflow/core/grappler/utils/traversal.cc | 80
-rw-r--r--  tensorflow/core/grappler/utils/traversal.h | 39
-rw-r--r--  tensorflow/core/grappler/utils/traversal_test.cc | 101
-rw-r--r--  tensorflow/core/kernels/batch_util.cc | 113
-rw-r--r--  tensorflow/core/kernels/batch_util.h | 10
-rw-r--r--  tensorflow/core/kernels/data/BUILD | 15
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.cc | 30
-rw-r--r--  tensorflow/core/kernels/data/padded_batch_dataset_op.cc | 116
-rw-r--r--  tensorflow/core/kernels/data/tensor_queue_dataset_op.cc | 646
-rw-r--r--  tensorflow/core/kernels/gather_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/matrix_band_part_op.cc | 12
-rw-r--r--  tensorflow/core/kernels/strided_slice_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/strided_slice_op_impl.h | 3
-rw-r--r--  tensorflow/core/ops/array_ops.cc | 5
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 96
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc | 25
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 77
-rw-r--r--  tensorflow/core/platform/cloud/gcs_throttle.h | 2
-rw-r--r--  tensorflow/core/platform/cloud/gcs_throttle_test.cc | 2
-rw-r--r--  tensorflow/docs_src/api_guides/python/TPUEstimator.md | 396
-rw-r--r--  tensorflow/docs_src/install/install_sources.md | 2
-rw-r--r--  tensorflow/docs_src/performance/performance_guide.md | 2
-rw-r--r--  tensorflow/docs_src/programmers_guide/saved_model.md | 26
-rw-r--r--  tensorflow/go/op/wrappers.go | 50
-rw-r--r--  tensorflow/python/data/kernel_tests/BUILD | 3
-rw-r--r--  tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py | 37
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py | 25
-rw-r--r--  tensorflow/python/data/util/nest.py | 4
-rw-r--r--  tensorflow/python/data/util/nest_test.py | 4
-rw-r--r--  tensorflow/python/eager/function.py | 95
-rw-r--r--  tensorflow/python/eager/function_test.py | 72
-rwxr-xr-x  tensorflow/python/keras/BUILD | 1
-rw-r--r--  tensorflow/python/kernel_tests/array_ops_test.py | 26
-rw-r--r--  tensorflow/python/kernel_tests/losses_test.py | 28
-rw-r--r--  tensorflow/python/kernel_tests/matrix_band_part_op_test.py | 11
-rw-r--r--  tensorflow/python/ops/losses/losses_impl.py | 2
-rw-r--r--  tensorflow/python/util/nest.py | 4
-rw-r--r--  tensorflow/python/util/nest_test.py | 4
-rw-r--r--  tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat | 2
-rw-r--r--  tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat | 2
-rw-r--r--  tensorflow/tools/docs/pretty_docs.py | 14
-rw-r--r--  tensorflow/tools/pip_package/setup.py | 4
105 files changed, 5041 insertions, 623 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 2fa02a9b4c..c225cc1a74 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -553,6 +553,7 @@ filegroup(
"//tensorflow/contrib/py2tf/impl:all_files",
"//tensorflow/contrib/py2tf/pyct:all_files",
"//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
+ "//tensorflow/contrib/py2tf/utils:all_files",
"//tensorflow/contrib/quantize:all_files",
"//tensorflow/contrib/receptive_field:all_files",
"//tensorflow/contrib/reduce_slice_ops:all_files",
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 2b9c83ba14..58572fea3d 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -4,7 +4,7 @@
To use from your BUILD file, add the following line to load the macro:
-load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
Then call the macro like this:
@@ -16,14 +16,15 @@ tf_library(
)
"""
-load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_android", "tf_copts")
+load("//tensorflow:tensorflow.bzl",
+ "if_android", "tf_cc_test", "tf_copts")
def tf_library(name, graph, config,
freeze_checkpoint=None, freeze_saver=None,
cpp_class=None, gen_test=True, gen_benchmark=True,
visibility=None, testonly=None,
tfcompile_flags=None,
- tfcompile_tool="@org_tensorflow//tensorflow/compiler/aot:tfcompile",
+ tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
include_standard_runtime_deps=True, deps=None, tags=None):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
@@ -119,9 +120,9 @@ def tf_library(name, graph, config,
out_nodes_file,
] + freeze_saver_srcs,
outs=[freeze_file],
- cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" +
+ cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
freeze_args),
- tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"],
+ tools=["//tensorflow/python/tools:freeze_graph"],
tags=tags,
)
tfcompile_graph = freeze_file
@@ -213,22 +214,22 @@ def tf_library(name, graph, config,
# These deps are required by all tf_library targets even if
# include_standard_runtime_deps is False. Without them, the
# generated code will fail to compile.
- "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
- "@org_tensorflow//tensorflow/core:framework_lite",
+ "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+ "//tensorflow/core:framework_lite",
] + (need_xla_data_proto and [
# If we're generating the program shape, we must depend on the proto.
- "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_data_proto",
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
- "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
- "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
- "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
- "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
- "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
- "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
- "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_matmul",
- "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
- "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
+ "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
+ "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
+ "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
+ "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
+ "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
+ "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
+ "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//third_party/eigen3",
] or []) + (deps or []),
tags=tags,
@@ -254,28 +255,32 @@ def tf_library(name, graph, config,
name=("gen_" + test_name),
testonly=1,
srcs=[
- "@org_tensorflow//tensorflow/compiler/aot:test.cc",
+ "//tensorflow/compiler/aot:test.cc",
header_file,
],
outs=[test_file],
cmd=("sed " + sed_replace +
- " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " +
+ " $(location //tensorflow/compiler/aot:test.cc) " +
"> $(OUTS)"),
tags=tags,
)
- # The cc_test rule for the generated code.
- native.cc_test(
+ # The cc_test rule for the generated code. To ensure that this works
+ # reliably across build configurations, we must use tf_cc_test instead of
+ # native.cc_test. This is related to how we build
+ # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
+ # for more details.
+ tf_cc_test(
name=test_name,
srcs=[test_file],
deps=[
":" + name,
- "@org_tensorflow//tensorflow/compiler/aot:runtime",
- "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main",
- "@org_tensorflow//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/aot:runtime",
+ "//tensorflow/compiler/aot:tf_library_test_main",
+ "//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
- "@org_tensorflow//tensorflow/core:lib",
- "@org_tensorflow//tensorflow/core:test",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
],
tags=tags,
)
@@ -283,7 +288,7 @@ def tf_library(name, graph, config,
if gen_benchmark:
benchmark_name = name + "_benchmark"
benchmark_file = benchmark_name + ".cc"
- benchmark_main = ("@org_tensorflow//tensorflow/compiler/aot:" +
+ benchmark_main = ("//tensorflow/compiler/aot:" +
"benchmark_main.template")
# Rule to rewrite benchmark.cc to produce the benchmark_file.
@@ -301,7 +306,9 @@ def tf_library(name, graph, config,
tags=tags,
)
- # The cc_benchmark rule for the generated code.
+ # The cc_benchmark rule for the generated code. This does not need the
+ # tf_cc_binary since we (by deliberate design) do not depend on
+ # //tensorflow/core:lib.
#
# Note: to get smaller size on android for comparison, compile with:
# --copt=-fvisibility=hidden
@@ -315,12 +322,12 @@ def tf_library(name, graph, config,
linkopts = if_android(["-pie", "-s"]),
deps=[
":" + name,
- "@org_tensorflow//tensorflow/compiler/aot:benchmark",
- "@org_tensorflow//tensorflow/compiler/aot:runtime",
- "@org_tensorflow//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/aot:benchmark",
+ "//tensorflow/compiler/aot:runtime",
+ "//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
] + if_android([
- "@org_tensorflow//tensorflow/compiler/aot:benchmark_extra_android",
+ "//tensorflow/compiler/aot:benchmark_extra_android",
]),
tags=tags,
)
@@ -330,11 +337,11 @@ def target_llvm_triple():
# TODO(toddw): Add target_triple for other targets. For details see:
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
return select({
- "@org_tensorflow//tensorflow:android_armeabi": "armv5-none-android",
- "@org_tensorflow//tensorflow:android_arm": "armv7-none-android",
- "@org_tensorflow//tensorflow:android_arm64": "aarch64-none-android",
- "@org_tensorflow//tensorflow:android_x86": "i686-none-android",
- "@org_tensorflow//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
- "@org_tensorflow//tensorflow:darwin": "x86_64-none-darwin",
+ "//tensorflow:android_armeabi": "armv5-none-android",
+ "//tensorflow:android_arm": "armv7-none-android",
+ "//tensorflow:android_arm64": "aarch64-none-android",
+ "//tensorflow:android_x86": "i686-none-android",
+ "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
+ "//tensorflow:darwin": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",
})
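The tf_library docstring above ends with "Then call the macro like this:", but the example invocation itself falls outside the hunks shown here. As a purely hypothetical sketch (Starlark, which uses Python-like syntax; the target, graph, config, and class names below are made up, not taken from the repository), an invocation might look like:

tf_library(
    name = "my_graph_compiled",              # hypothetical target name
    graph = "my_graph.pb",                   # frozen GraphDef to compile
    config = "my_graph.config.pbtxt",        # tf2xla Config proto listing feeds/fetches
    cpp_class = "mynamespace::MyGraphComp",  # namespace-qualified generated C++ class
)

With gen_test=True (the default in the signature above), the macro also emits a test target for the generated code, which is the rule this change switches from native.cc_test to tf_cc_test.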
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 7277ba42ce..b0b038775f 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -354,6 +354,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "matrix_band_part_test",
+ size = "medium",
+ srcs = ["matrix_band_part_test.py"],
+ tags = ["optonly"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "momentum_test",
size = "small",
srcs = ["momentum_test.py"],
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index c95fb1c515..30a6d3a74d 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1181,6 +1181,50 @@ class BinaryOpsTest(XLATestCase):
np.array([4, 5, 6], dtype=np.int32),
expected=None)
+ def testMatrixSetDiag(self):
+ for dtype in self.numeric_types:
+ # Square
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
+ dtype=dtype),
+ np.array([1.0, 2.0, 3.0], dtype=dtype),
+ expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
+ dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
+ [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
+ dtype=dtype),
+ np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
+ expected=np.array(
+ [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
+ [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
+ dtype=dtype))
+
+ # Rectangular
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
+ np.array([3.0, 4.0], dtype=dtype),
+ expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
+ np.array([3.0, 4.0], dtype=dtype),
+ expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
+ [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
+ np.array([[-1.0, -2.0], [-4.0, -5.0]],
+ dtype=dtype),
+ expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
+ [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
+ dtype=dtype))
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
new file mode 100644
index 0000000000..29394f9ea5
--- /dev/null
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class MatrixBandPartTest(XLATestCase):
+
+ def _testMatrixBandPart(self, dtype, shape):
+ with self.test_session():
+ batch_shape = shape[:-2]
+ mat = np.ones(shape).astype(dtype)
+ batch_mat = np.tile(mat, batch_shape + [1, 1])
+ for lower in -1, 0, 1, shape[-2] - 1:
+ for upper in -1, 0, 1, shape[-1] - 1:
+ band_np = mat
+ if lower >= 0:
+ band_np = np.triu(band_np, -lower)
+ if upper >= 0:
+ band_np = np.tril(band_np, upper)
+ if batch_shape:
+ band_np = np.tile(band_np, batch_shape + [1, 1])
+
+ placeholder = array_ops.placeholder(dtype)
+ with self.test_scope():
+ band = array_ops.matrix_band_part(
+ placeholder,
+ constant_op.constant(lower, dtype=dtypes.int32),
+ constant_op.constant(upper, dtype=dtypes.int32))
+ feed_dict = {placeholder: batch_mat}
+ self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
+
+ def testMatrixBandPart(self):
+ for dtype in self.float_types:
+ for batch_shape in [[], [2,], [1, 3, 2]]:
+ for rows in 1, 2, 7:
+ for cols in 1, 2, 7:
+ self._testMatrixBandPart(dtype, batch_shape + [rows, cols])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 67be1a4ba6..e9be6f8476 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -44,6 +44,8 @@ tf_kernel_library(
"l2loss_op.cc",
"lrn_ops.cc",
"matmul_op.cc",
+ "matrix_band_part_op.cc",
+ "matrix_set_diag_op.cc",
"matrix_triangular_solve_op.cc",
"mirror_pad_op.cc",
"no_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
new file mode 100644
index 0000000000..faa415a97b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace {
+
+class MatrixBandPartOp : public XlaOpKernel {
+ public:
+ explicit MatrixBandPartOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument(
+ "input must be at least 2-dim, received shape: ",
+ input_shape.DebugString()));
+
+ const TensorShape num_lower_in_shape = context->InputShape(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape),
+ errors::InvalidArgument("num_lower must be scalar, got shape ",
+ num_lower_in_shape.DebugString()));
+
+ const TensorShape num_upper_in_shape = context->InputShape(2);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape),
+ errors::InvalidArgument("num_upper must be scalar, got shape ",
+ num_upper_in_shape.DebugString()));
+
+ xla::ComputationBuilder* builder = context->builder();
+ xla::ComputationDataHandle input = context->Input(0);
+ xla::ComputationDataHandle num_lower = context->Input(1);
+ xla::ComputationDataHandle num_upper = context->Input(2);
+ DataType input_type = context->input_type(0);
+ DataType index_type = context->input_type(1);
+
+ TensorShape batch_shape = input_shape;
+ batch_shape.RemoveLastDims(2);
+ const int64 m = input_shape.dim_size(input_shape.dims() - 2);
+ const int64 n = input_shape.dim_size(input_shape.dims() - 1);
+
+ // Compute 'offset', which is how many diagonals we are above/below the
+ // diagonal.
+ xla::ComputationDataHandle iota_m;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
+
+ xla::ComputationDataHandle iota_n;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
+
+ auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
+ /*broadcast_dimensions=*/{0});
+
+ // If num_lower or num_upper are negative, include all lower/upper
+ // diagonals.
+ auto zero_index = XlaHelpers::Zero(builder, index_type);
+ num_lower = builder->Select(
+ builder->Lt(num_lower, zero_index),
+ XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower);
+ num_upper = builder->Select(
+ builder->Lt(num_upper, zero_index),
+ XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper);
+
+ auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset),
+ builder->Le(offset, num_upper));
+ indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
+
+ auto zero_input = XlaHelpers::Zero(builder, input_type);
+ auto output = builder->Select(
+ indicator, input,
+ builder->Broadcast(zero_input, input_shape.dim_sizes()));
+
+ context->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp);
+};
+REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp);
+
+} // namespace
+} // namespace tensorflow
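For intuition, the kernel above forms an m-by-n "offset" matrix with offset[i, j] = j - i and keeps input element (i, j) whenever -num_lower <= offset <= num_upper, where a negative bound means the whole lower or upper triangle is kept. A rough NumPy reference model of that masking logic (illustrative only, not the XLA code path; names are made up):

import numpy as np

def band_part_reference(x, num_lower, num_upper):
  """NumPy model of MatrixBandPart: keep a band around the main diagonal."""
  m, n = x.shape[-2], x.shape[-1]
  # offset[i, j] = j - i: how far element (i, j) sits above the main diagonal.
  offset = np.arange(n)[np.newaxis, :] - np.arange(m)[:, np.newaxis]
  # A negative bound means "keep all lower (or upper) diagonals".
  lower = m if num_lower < 0 else num_lower
  upper = n if num_upper < 0 else num_upper
  indicator = (-lower <= offset) & (offset <= upper)
  return np.where(indicator, x, np.zeros_like(x))

# Keep one sub-diagonal and no super-diagonals of a 4x4 matrix of ones.
print(band_part_reference(np.ones((4, 4)), num_lower=1, num_upper=0))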
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
new file mode 100644
index 0000000000..b2940bdcff
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+
+class MatrixSetDiagOp : public XlaOpKernel {
+ public:
+ explicit MatrixSetDiagOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ const TensorShape diag_shape = context->InputShape(1);
+
+ const int rank = input_shape.dims();
+
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument(
+ "input must be at least 2-dim, received shape: ",
+ input_shape.DebugString()));
+
+ // Check to make sure the last dimension of diag is equal to the smaller of
+ // the last two dimensions of input.
+ const int64 m = input_shape.dim_size(rank - 2);
+ const int64 n = input_shape.dim_size(rank - 1);
+ const int64 min_dim = std::min(m, n);
+
+ TensorShape batch_shape = input_shape;
+ batch_shape.RemoveLastDims(2);
+
+ TensorShape expected_diag_shape = batch_shape;
+ expected_diag_shape.AddDim(min_dim);
+ OP_REQUIRES(context, expected_diag_shape == diag_shape,
+ errors::InvalidArgument(
+ "must have diagonal.shape == input.shape[:-2] + "
+ "min(input.shape[-2:]), but received input shape: ",
+ input_shape.DebugString(),
+ " and diagonal shape: ", diag_shape.DebugString()));
+
+ xla::ComputationBuilder* builder = context->builder();
+ xla::ComputationDataHandle input = context->Input(0);
+ xla::ComputationDataHandle diag = context->Input(1);
+
+ auto zero = XlaHelpers::Zero(builder, context->input_type(0));
+
+ // Create an indicator tensor that is true only on the diagonal.
+ xla::ComputationDataHandle iota_m;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
+ xla::ComputationDataHandle iota_n;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
+ auto indicator = builder->Eq(iota_m,
+ builder->Broadcast(iota_n, {m}),
+ /*broadcast_dimensions=*/{0});
+ indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
+
+ // Broadcast diag up to the input shape. Use an implicit broadcast (Add)
+ // because we need to broadcast on the right.
+ std::vector<int64> diag_broadcast_dims(rank - 1);
+ std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
+ if (min_dim != m) {
+ diag_broadcast_dims.back() = rank - 1;
+ }
+ diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()),
+ /*broadcast_dimensions=*/diag_broadcast_dims);
+
+ auto output = builder->Select(indicator, diag, input);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
+};
+
+REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
+
+} // namespace tensorflow
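The MatrixSetDiag kernel uses the same select-against-an-indicator pattern: an indicator that is true exactly on the main diagonal of the trailing two dimensions, and the diag values broadcast up to the input shape along the longer axis. A rough single-matrix NumPy model (illustrative only, not the XLA code path):

import numpy as np

def set_diag_reference(x, diag):
  """NumPy model of MatrixSetDiag for one (m, n) matrix."""
  m, n = x.shape
  # True only where the row and column indices coincide (the main diagonal).
  indicator = np.arange(m)[:, np.newaxis] == np.arange(n)[np.newaxis, :]
  # Broadcast diag so that position (i, i) sees diag[i]; which axis to
  # broadcast along depends on whether rows or columns are the shorter side.
  if m <= n:
    diag_full = np.broadcast_to(diag[:, np.newaxis], (m, n))
  else:
    diag_full = np.broadcast_to(diag[np.newaxis, :], (m, n))
  return np.where(indicator, diag_full, x)

# Puts 7 and 8 on the diagonal of a 2x3 zero matrix.
print(set_diag_reference(np.zeros((2, 3)), np.array([7.0, 8.0])))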
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 67a73bc33d..8386acf0cd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -98,15 +98,25 @@ const std::unique_ptr<ScopedShapedBuffer>& LocalShapedBuffer::shaped_buffer()
return shaped_buffer_;
}
+static StatusOr<std::unique_ptr<ScopedShapedBuffer>> ToBuffer(
+ LocalClient* client, int device_ordinal, const Literal& arg) {
+ return client->LiteralToShapedBuffer(arg, device_ordinal,
+ client->backend().memory_allocator());
+}
+
/* static */
-LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) {
+LocalShapedBuffer* LocalShapedBuffer::FromLiteral(
+ const Literal& argument,
+ const tensorflow::gtl::optional<Shape>& shape_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
- std::unique_ptr<ScopedShapedBuffer> buf =
- client
- ->LiteralToShapedBuffer(argument,
- /*device_ordinal=*/0,
- client->backend().memory_allocator())
- .ConsumeValueOrDie();
+ std::unique_ptr<ScopedShapedBuffer> buf;
+ if (shape_with_layout) {
+ std::unique_ptr<Literal> relaid =
+ argument.Relayout(shape_with_layout.value());
+ buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie();
+ } else {
+ buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie();
+ }
return new LocalShapedBuffer(std::move(buf));
}
@@ -120,7 +130,8 @@ CompiledLocalComputation::CompiledLocalComputation(
: executable_(std::move(executable)) {}
StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
- const std::vector<Literal>& arguments) {
+ const std::vector<Literal>& arguments,
+ const std::vector<tensorflow::gtl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas.";
@@ -133,7 +144,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
GetReplicaCount());
for (int replica = 0; replica < GetReplicaCount(); ++replica) {
- pool.Schedule([this, client, replica, &arguments, &results] {
+ pool.Schedule([this, client, replica, &arguments, &shapes_with_layout,
+ &results] {
StatusOr<int> device_ordinal_status =
client->ReplicaNumberToDeviceOrdinal(replica);
if (!device_ordinal_status.ok()) {
@@ -144,18 +156,28 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
VLOG(3) << "Replica " << replica
<< " mapped to device ordinal for execution: "
<< device_ordinal;
+
// Transfer arguments in
std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
scoped_buffers.reserve(arguments.size());
- for (const Literal& argument : arguments) {
- StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed =
- client->LiteralToShapedBuffer(
- argument, device_ordinal,
- client->backend().memory_allocator());
+ for (int i = 0; i < arguments.size(); ++i) {
+ const Literal& argument = arguments[i];
+ const tensorflow::gtl::optional<Shape>& shape_with_layout =
+ shapes_with_layout[i];
+
+ StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed;
+ if (shape_with_layout) {
+ std::unique_ptr<Literal> relaid =
+ argument.Relayout(shape_with_layout.value());
+ pushed = ToBuffer(client, device_ordinal, *relaid);
+ } else {
+ pushed = ToBuffer(client, device_ordinal, argument);
+ }
if (!pushed.ok()) {
results[replica] = pushed.status();
return;
}
+
scoped_buffers.push_back(std::move(pushed).ValueOrDie());
}
@@ -382,6 +404,12 @@ ComputationDataHandle LocalComputationBuilder::Dot(
return builder_.Dot(lhs, rhs);
}
+ComputationDataHandle LocalComputationBuilder::DotGeneral(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ const DotDimensionNumbers& dimension_numbers) {
+ return builder_.DotGeneral(lhs, rhs, dimension_numbers);
+}
+
ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index d5c4c58040..f39d15cff7 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -59,7 +59,9 @@ StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
// client.
class LocalShapedBuffer {
public:
- static LocalShapedBuffer* FromLiteral(const Literal& argument);
+ static LocalShapedBuffer* FromLiteral(
+ const Literal& argument,
+ const tensorflow::gtl::optional<Shape>& shape_with_layout);
LocalShapedBuffer(std::unique_ptr<ScopedShapedBuffer> shaped_buffer);
const std::unique_ptr<ScopedShapedBuffer>& shaped_buffer() const;
std::unique_ptr<Literal> ToLiteral() const;
@@ -77,8 +79,15 @@ class LocalShapedBuffer {
class CompiledLocalComputation {
public:
CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
+
+ // Execute the computation with the given argument literals, and
+ // with optionally-specified argument layouts. The literals will be
+ // re-laid out according to the corresponding elements of
+ // shapes_with_layout.
StatusOr<std::unique_ptr<Literal> > Execute(
- const std::vector<Literal>& arguments);
+ const std::vector<Literal>& arguments,
+ const std::vector<tensorflow::gtl::optional<Shape> >& shapes_with_layout);
+
LocalShapedBuffer* ExecuteWithShapedBuffers(
tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);
@@ -183,6 +192,10 @@ class LocalComputationBuilder {
ComputationDataHandle Dot(const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs);
+ ComputationDataHandle DotGeneral(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+
ComputationDataHandle ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 89f8385501..5ea75550c9 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -27,12 +27,14 @@ limitations under the License.
// ArraySlice<ComputationDataHandle> <- sequence of int
// Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
-// Shape <-> pair holding (dtype, dimensions)
-// std::vector<Shape> <- sequence of shape information pairs
+// Shape -> pair holding (dtype, dimensions)
+// <- object duck-typed as xla_client.Shape
+// std::vector<Shape> <- sequence of xla_client.Shape objects
// PrimitiveType <- int
// ArraySlice<pair<int64, in64>> <- sequence of int pairs
// PaddingConfig proto <- corresponding Python proto
// ConvolutionDimensionNumbers proto <- corresponding Python proto
+// DotDimensionNumbers proto <- corresponding Python proto
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally.
@@ -55,7 +57,7 @@ limitations under the License.
// translates to a tuple-shaped XLA Literal, whose component subshapes
// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
//
-// The Python objects corresponding to C++ Shapes have the type:
+// Shapes output by C++ become Python objects with the type:
//
// T = (dtype, S)
// S = DIMENSIONS | TUPLE_SHAPES
@@ -353,15 +355,31 @@ tensorflow::ImportNumpy();
// Shape
%typemap(in) const Shape& (Shape temp) {
- Status shape_status = numpy::CheckPyShapeInfo($input);
- if (!shape_status.ok()) {
- PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str());
+ StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
+ if (!statusor.ok()) {
+ PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
- temp = numpy::XlaShapeFromPyShapeInfo($input);
+ temp = std::move(statusor).ValueOrDie();
$1 = &temp;
}
+%typemap(in) const tensorflow::gtl::optional<Shape>& (
+ tensorflow::gtl::optional<Shape> temp) {
+ if ($input == Py_None) {
+ temp = tensorflow::gtl::nullopt;
+ $1 = &temp;
+ } else {
+ StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
+ if (!statusor.ok()) {
+ PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
+ return NULL;
+ }
+ temp = std::move(statusor).ValueOrDie();
+ $1 = &temp;
+ }
+}
+
%typemap(out) std::unique_ptr<Shape> {
$result = numpy::PyShapeInfoFromXlaShape(*$1);
}
@@ -374,14 +392,37 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
- Status shape_status = numpy::CheckPyShapeInfo(o);
- if (!shape_status.ok()) {
- PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str());
- Py_DECREF(o);
+ StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
+ Py_DECREF(o);
+ if (!statusor.ok()) {
+ PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
- temps.push_back(numpy::XlaShapeFromPyShapeInfo(o));
- Py_DECREF(o);
+ temps.push_back(statusor.ConsumeValueOrDie());
+ }
+ $1 = &temps;
+}
+
+%typemap(in) const std::vector<tensorflow::gtl::optional<Shape> >& (
+ std::vector<tensorflow::gtl::optional<Shape> > temps) {
+ if (!PySequence_Check($input)) {
+ PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
+ return NULL;
+ }
+ const int size = PySequence_Size($input);
+ for (int i = 0; i < size; ++i) {
+ PyObject* o = PySequence_GetItem($input, i);
+ if (o == Py_None) {
+ temps.push_back(tensorflow::gtl::nullopt);
+ } else {
+ StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
+ Py_DECREF(o);
+ if (!statusor.ok()) {
+ PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
+ return NULL;
+ }
+ temps.push_back(statusor.ConsumeValueOrDie());
+ }
}
$1 = &temps;
}
@@ -471,6 +512,135 @@ tensorflow::ImportNumpy();
$1 = temps;
}
+// DotDimensionNumbers
+
+%typemap(in) const DotDimensionNumbers&
+ (DotDimensionNumbers dimension_numbers) {
+ int length;
+
+ /* lhs_contracting_dimensions */
+ PyObject* lhs_contracting_dimensions = PyObject_GetAttrString(
+ $input, "lhs_contracting_dimensions");
+ if (!lhs_contracting_dimensions) {
+ return NULL;
+ }
+
+ length = PySequence_Size(lhs_contracting_dimensions);
+ if (length == -1) {
+ Py_DECREF(lhs_contracting_dimensions);
+ return NULL;
+ }
+
+ for (int i = 0; i < length; ++i) {
+ PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i);
+ if (!item) {
+ Py_DECREF(lhs_contracting_dimensions);
+ return NULL;
+ }
+ const int64 dimension = numpy::PyIntOrPyLongToLong(item);
+ if (dimension == -1 && PyErr_Occurred()) {
+ Py_DECREF(item);
+ Py_DECREF(lhs_contracting_dimensions);
+ return NULL;
+ }
+ dimension_numbers.add_lhs_contracting_dimensions(dimension);
+ Py_DECREF(item);
+ }
+ Py_DECREF(lhs_contracting_dimensions);
+
+ /* rhs_contracting_dimensions */
+ PyObject* rhs_contracting_dimensions = PyObject_GetAttrString(
+ $input, "rhs_contracting_dimensions");
+  if (!rhs_contracting_dimensions) {
+ return NULL;
+ }
+
+ length = PySequence_Size(rhs_contracting_dimensions);
+ if (length == -1) {
+ Py_DECREF(rhs_contracting_dimensions);
+ return NULL;
+ }
+
+ for (int i = 0; i < length; ++i) {
+ PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i);
+ if (!item) {
+ Py_DECREF(rhs_contracting_dimensions);
+ return NULL;
+ }
+ const int64 dimension = numpy::PyIntOrPyLongToLong(item);
+ if (dimension == -1 && PyErr_Occurred()) {
+ Py_DECREF(item);
+ Py_DECREF(rhs_contracting_dimensions);
+ return NULL;
+ }
+ dimension_numbers.add_rhs_contracting_dimensions(dimension);
+ Py_DECREF(item);
+ }
+ Py_DECREF(rhs_contracting_dimensions);
+
+ /* lhs_batch_dimensions */
+ PyObject* lhs_batch_dimensions = PyObject_GetAttrString(
+ $input, "lhs_batch_dimensions");
+ if (!lhs_batch_dimensions) {
+ return NULL;
+ }
+
+ length = PySequence_Size(lhs_batch_dimensions);
+ if (length == -1) {
+ Py_DECREF(lhs_batch_dimensions);
+ return NULL;
+ }
+
+ for (int i = 0; i < length; ++i) {
+ PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i);
+ if (!item) {
+ Py_DECREF(lhs_batch_dimensions);
+ return NULL;
+ }
+ const int64 dimension = numpy::PyIntOrPyLongToLong(item);
+ if (dimension == -1 && PyErr_Occurred()) {
+ Py_DECREF(item);
+ Py_DECREF(lhs_batch_dimensions);
+ return NULL;
+ }
+ dimension_numbers.add_lhs_batch_dimensions(dimension);
+ Py_DECREF(item);
+ }
+ Py_DECREF(lhs_batch_dimensions);
+
+ /* rhs_batch_dimensions */
+ PyObject* rhs_batch_dimensions = PyObject_GetAttrString(
+ $input, "rhs_batch_dimensions");
+ if (!rhs_batch_dimensions) {
+ return NULL;
+ }
+
+ length = PySequence_Size(rhs_batch_dimensions);
+ if (length == -1) {
+ Py_DECREF(rhs_batch_dimensions);
+ return NULL;
+ }
+
+ for (int i = 0; i < length; ++i) {
+ PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i);
+ if (!item) {
+ Py_DECREF(rhs_batch_dimensions);
+ return NULL;
+ }
+ const int64 dimension = numpy::PyIntOrPyLongToLong(item);
+ if (dimension == -1 && PyErr_Occurred()) {
+ Py_DECREF(item);
+ Py_DECREF(rhs_batch_dimensions);
+ return NULL;
+ }
+ dimension_numbers.add_rhs_batch_dimensions(dimension);
+ Py_DECREF(item);
+ }
+ Py_DECREF(rhs_batch_dimensions);
+
+ $1 = &dimension_numbers;
+}
+
// PaddingConfig
%typemap(in) const PaddingConfig&
@@ -716,6 +886,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot;
+%unignore xla::swig::LocalComputationBuilder::DotGeneral;
%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
%unignore xla::swig::LocalComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub;
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index 5c722623e3..3d87480728 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -176,85 +176,107 @@ static string PyObjectCppRepr(PyObject* o) {
return ExtractStringAndDecref(r);
}
-Status CheckPyShapeInfo(PyObject* o) {
+StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
auto error = [o](const string& prefix) {
return InvalidArgument("%s; got %s", prefix.c_str(),
PyObjectCppRepr(o).c_str());
};
- // The object is a tuple (a pair)
- if (!PyTuple_Check(o)) {
- return error("Shape record must be a tuple");
- }
- if (PyTuple_Size(o) != 2) {
- return error("Shape record tuple must be of length 2");
- }
- // It has a first element, which is a numpy dtype object
- PyObject* first = PyTuple_GetItem(o, 0);
- if (first == nullptr) {
- return error("Tuple has no item 0 (shape dtype)");
- }
- if (first->ob_type != &PyArrayDescr_Type) {
- return error(
- "Shape record does not have a numpy dtype as its first element");
- }
- const int np_type = NumpyTypenum(first);
- if (!NumpyTypeIsValid(np_type)) {
- return error("Shape record has an invalid integer dtype");
- }
+ auto get_attr = [o, &error](const string& field) -> StatusOr<PyObject*> {
+ PyObject* result =
+ PyObject_GetAttrString(o, const_cast<char*>(field.c_str()));
+ if (result == nullptr) {
+ return error(tensorflow::strings::StrCat(
+ "Failed to get attribute of Shape object:", field));
+ }
+ return result;
+ };
- // It has a second element, which is a tuple, either of shape
- // records or of Python ints
- PyObject* second = PyTuple_GetItem(o, 1);
- if (!second) {
- return error("Tuple has no item 0 (shape dimensions)");
- }
- if (!PyTuple_Check(second)) {
- return error("Shape record does not have a tuple as its second element");
- }
- const int length = PyTuple_Size(second);
- const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
- for (int i = 0; i < length; i++) {
- PyObject* dimension = PyTuple_GetItem(second, i);
- if (element_type == TUPLE) {
- VLOG(3) << "element_type is tuple, checking member: " << i;
- Status result = CheckPyShapeInfo(dimension);
- if (!result.ok()) {
- return AddStatus(
- result, tensorflow::strings::StrCat("Validating tuple member ", i,
- " of ", PyObjectCppRepr(o)));
- }
- } else if (!CheckPyIntOrLong(dimension)) {
- return error("Non-tuple shape record has a non-integer dimension");
+ auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> {
+ PyObject* result =
+ PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
+ if (result == nullptr) {
+ return error(tensorflow::strings::StrCat(
+ "Failed to call method of shape object:", method));
}
- }
+ return result;
+ };
- return Status::OK();
-}
+ PyObject* np_type;
+ TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype"));
+ if (np_type->ob_type != &PyArrayDescr_Type) {
+ return error("Shape attribute np_dtype is not an integer numpy dtype");
+ }
+ if (!NumpyTypeIsValid(NumpyTypenum(np_type))) {
+ return error("Shape attribute np_dtype is not a valid integer numpy dtype");
+ }
+ const PrimitiveType element_type =
+ NumpyTypeToPrimitiveType(NumpyTypenum(np_type));
+ Py_DECREF(np_type);
-// Precondition: CheckPyShapeInfo(o)
-Shape XlaShapeFromPyShapeInfo(PyObject* o) {
- const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0));
- const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
- PyObject* py_dimensions = PyTuple_GetItem(o, 1);
- const int length = PyTuple_Size(py_dimensions);
if (element_type == TUPLE) {
+ PyObject* py_subshapes;
+ TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes"));
+ if (!PyTuple_Check(py_subshapes)) {
+ return error(
+ "Return value of Shape method tuple_shapes() is not a tuple");
+ }
+ const int length = PyTuple_Size(py_subshapes);
std::vector<Shape> subshapes;
subshapes.reserve(length);
for (int i = 0; i < length; i++) {
- subshapes.push_back(
- XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i)));
+ TF_ASSIGN_OR_RETURN(
+ const Shape& subshape,
+ XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i)));
+ subshapes.push_back(subshape);
}
+ Py_DECREF(py_subshapes);
return ShapeUtil::MakeTupleShape(subshapes);
} else {
+ PyObject* py_dimensions;
+ PyObject* py_minor_to_major;
+ TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions"));
+ TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major"));
+ if (!PyTuple_Check(py_dimensions)) {
+ return error("Return value of Shape method dimensions() is not a tuple");
+ }
+ if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) {
+ return error(
+ "Return value of Shape method minor_to_major() is neither a tuple "
+ "nor None");
+ }
+ const int length = PyTuple_Size(py_dimensions);
+ if (py_minor_to_major != Py_None &&
+ length != PyTuple_Size(py_minor_to_major)) {
+ return error(
+ "Shape methods dimensions() and minor_to_major() return "
+ "different-length tuples");
+ }
std::vector<int64> dimensions(length);
+ std::vector<int64> minor_to_major(length);
for (int i = 0; i < length; i++) {
dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
- if (dimensions[i] == -1) {
- CHECK(!PyErr_Occurred());
+ if (dimensions[i] == -1 && PyErr_Occurred()) {
+ return error("Dimension is not an int");
}
+
+ if (py_minor_to_major != Py_None) {
+ minor_to_major[i] =
+ PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i));
+ if (minor_to_major[i] == -1 && PyErr_Occurred()) {
+ return error("Minor-to-major value is not an int");
+ }
+ }
+ }
+ bool with_layout = py_minor_to_major != Py_None;
+ Py_DECREF(py_dimensions);
+ Py_DECREF(py_minor_to_major);
+ if (with_layout) {
+ return ShapeUtil::MakeShapeWithLayout(element_type, dimensions,
+ minor_to_major);
+ } else {
+ return ShapeUtil::MakeShape(element_type, dimensions);
}
- return ShapeUtil::MakeShape(element_type, dimensions);
}
}
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 6ff1c34cfc..adfcc3b858 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -56,15 +56,11 @@ bool NumpyTypeIsValid(int np_type);
// The return value is a new reference.
PyObject* PyShapeInfoFromXlaShape(const Shape& shape);
-// Returns the outcome of a best-effort check that the Python object
-// is a pair of the form (numpy dtype, dimensions), as produced by
-// PyShapeInfoFromXlaShape.
-Status CheckPyShapeInfo(PyObject* o);
-
-// Performs the inverse conversion to that of PyShapeInfoFromXlaShape.
+// Converts a Python object with a method interface matching that of
+// xla_client.Shape into an XLA Shape object.
//
// The return value is a new reference.
-Shape XlaShapeFromPyShapeInfo(PyObject* o);
+StatusOr<Shape> XlaShapeFromPyShape(PyObject* o);
// Converts a PyObject that represents operation metadata into protocol buffer
// form.
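XlaShapeFromPyShape now duck-types its argument instead of requiring a (dtype, dimensions) pair: it reads an np_dtype attribute and calls the dimensions(), minor_to_major() and tuple_shapes() methods, the interface that the xla_client.Shape class further down this diff provides. A minimal stand-in satisfying that contract might look like this (hypothetical, for illustration only; real callers should pass xla_client.Shape):

import numpy as np

class ArrayShapeStub(object):
  """Bare-bones object exposing the methods the numpy bridge calls."""

  def __init__(self, dtype, dims, minor_to_major=None):
    self.np_dtype = np.dtype(dtype)        # read as an attribute by the bridge
    self._dims = tuple(dims)
    self._minor_to_major = minor_to_major  # None means "no explicit layout"

  def dimensions(self):
    return self._dims

  def minor_to_major(self):
    return self._minor_to_major

  def tuple_shapes(self):
    raise ValueError('not a tuple shape')

shape = ArrayShapeStub(np.float32, (2, 3), minor_to_major=(0, 1))  # column-major 2x3 f32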
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 7ee5febc09..b890980955 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -156,9 +156,14 @@ class LocalBuffer(object):
self._delete = c_api.DeleteLocalShapedBuffer
@staticmethod
- def from_py(npval):
+ def from_py(npval, layout_fn=None):
npval = require_numpy_array_layout(npval)
- return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval))
+ if layout_fn:
+ shape = Shape.from_numpy(npval)
+ shape = shape.map_leaves(layout_fn)
+ else:
+ shape = None
+ return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape))
def to_py(self):
return self.c_local_shaped_buffer.ToLiteral()
@@ -183,13 +188,17 @@ class Shape(object):
represents an XLA tuple.
"""
- def __init__(self, np_dtype, dimensions):
+ def __init__(self, np_dtype, dimensions, minor_to_major=None):
+ assert isinstance(dimensions, tuple)
self.np_dtype = np_dtype
self._dimensions = dimensions
+ self._minor_to_major = minor_to_major
+ self._check_minor_to_major()
def __repr__(self):
- return 'xla_client.Shape(np_dtype={!r}, dimensions={!r})'.format(
- self.np_dtype, self._dimensions)
+ return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, '
+ 'minor_to_major={!r})').format(self.np_dtype, self._dimensions,
+ self._minor_to_major)
def element_type(self):
return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]
@@ -202,11 +211,49 @@ class Shape(object):
raise ValueError('Tuple shape has no dimensions')
return self._dimensions
+ def minor_to_major(self):
+ return self._minor_to_major
+
def tuple_shapes(self):
if not self.is_tuple():
raise ValueError('Shape is not a tuple shape')
return self._dimensions
+ def rank(self):
+ return len(self.dimensions())
+
+ def map_leaves(self, f):
+ """Map f over each leaf-level array subshape.
+
+ Args:
+ f: The function to apply. Whenever f returns None, the identity is
+ applied instead.
+
+ Returns:
+ A new Shape with the mapped leaves.
+ """
+ if self.is_tuple():
+ children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
+ return Shape(np.dtype('O'), children)
+ else:
+ mapped = f(self)
+ return self if mapped is None else mapped
+
+ def _check_minor_to_major(self):
+ mtm = self._minor_to_major
+ if self.is_tuple():
+ assert mtm is None, self
+ if mtm is not None:
+ assert self.rank() == len(mtm), self
+ assert sorted(mtm) == range(len(mtm)), self
+
+ def update_minor_to_major(self, minor_to_major):
+ if not isinstance(minor_to_major, tuple):
+ raise TypeError('minor_to_major must be a tuple')
+ updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major)
+ updated._check_minor_to_major() # pylint: disable=protected-access
+ return updated
+
@staticmethod
def from_numpy(npval):
@@ -223,23 +270,10 @@ def _wrap_shape(shape_info):
dtype, dims = shape_info
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
if element_type == xla_data_pb2.TUPLE:
- dims = [_wrap_shape(subshape_info) for subshape_info in dims]
+ dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims)
return Shape(dtype, dims)
-def _unwrap_shape(shape):
- if shape.is_tuple():
- components = tuple(
- _unwrap_shape(subshape) for subshape in shape.tuple_shapes())
- else:
- components = shape.dimensions()
- return (shape.np_dtype, components)
-
-
-def _unwrap_shapes(shapes):
- return [_unwrap_shape(shape) for shape in shapes]
-
-
def _wrap_data_handle(handle):
cdh = xla_data_pb2.ComputationDataHandle()
cdh.handle = handle
@@ -303,8 +337,7 @@ def transfer_from_outfeed(shape, replica_number=None):
Returns:
The literal value that is produced from the outfeed queue.
"""
- return c_api.TransferFromOutfeedLocalReplica(
- _unwrap_shape(shape), replica_number or 0)
+ return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0)
class LocalComputation(object):
@@ -325,24 +358,39 @@ class LocalComputation(object):
else:
self._delete = c_api.DeleteLocalComputation
- def Compile(self, argument_shapes=(), compile_options=None):
+ def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
if self.is_compiled:
raise ValueError('Attempt to compile a compiled local XLA computation.')
+ if layout_fn:
+ argument_shapes = [
+ shape.map_leaves(layout_fn) for shape in argument_shapes
+ ]
return LocalComputation(
- self.c_local_computation.Compile(
- _unwrap_shapes(argument_shapes), compile_options),
+ self.c_local_computation.Compile(argument_shapes, compile_options),
is_compiled=True)
- def CompileWithExampleArguments(self, arguments=(), compile_options=None):
+ def CompileWithExampleArguments(self,
+ arguments=(),
+ compile_options=None,
+ layout_fn=None):
return self.Compile(
argument_shapes=[Shape.from_numpy(arg) for arg in arguments],
- compile_options=compile_options)
+ compile_options=compile_options,
+ layout_fn=layout_fn)
- def Execute(self, arguments=()):
+ def Execute(self, arguments=(), layout_fn=None):
+ """Execute with Python values as arguments and return value."""
if not self.is_compiled:
raise ValueError('Cannot execute an uncompiled local XLA computation.')
+ argument_shapes = [Shape.from_numpy(arg) for arg in arguments]
+ if layout_fn:
+ argument_shapes = [
+ shape.map_leaves(layout_fn) for shape in argument_shapes
+ ]
+ else:
+ argument_shapes = [None for shape in argument_shapes]
arguments = tuple(map(require_numpy_array_layout, arguments))
- return self.c_local_computation.Execute(arguments)
+ return self.c_local_computation.Execute(arguments, argument_shapes)
def ExecuteWithLocalBuffers(self, arguments=()):
"""Execute with LocalBuffer arguments and return value."""
@@ -398,7 +446,7 @@ class ComputationBuilder(object):
Returns:
A ComputationDataHandle message.
"""
- return _wrap_data_handle(self._client.Infeed(_unwrap_shape(shape)))
+ return _wrap_data_handle(self._client.Infeed(shape))
def Outfeed(self, operand):
"""Enqueues an outfeed op onto the computation.
@@ -407,7 +455,7 @@ class ComputationBuilder(object):
outfeed queue for subsequent dequeue via the client API.
"""
self._client.Outfeed(
- _unwrap_data_handle(operand), _unwrap_shape(self.GetShape(operand)),
+ _unwrap_data_handle(operand), self.GetShape(operand),
''.encode('utf-8'))
def Constant(self, value):
@@ -498,8 +546,7 @@ class ComputationBuilder(object):
parameter_num = next(self._parameter_numbering)
return _wrap_data_handle(
- self._client.Parameter(
- parameter_num, _unwrap_shape(shape), name.encode('utf8')))
+ self._client.Parameter(parameter_num, shape, name.encode('utf8')))
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
"""Enqueues a Parameter op onto the computation.
@@ -846,8 +893,7 @@ class ComputationBuilder(object):
shape = Shape(self.GetShape(mu).np_dtype, dims)
return _wrap_data_handle(
self._client.RngNormal(
- _unwrap_data_handle(mu), _unwrap_data_handle(sigma),
- _unwrap_shape(shape)))
+ _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape))
def RngUniform(self, a, b, dims):
"""Enqueues an RngUniform operation onto the computation.
@@ -867,8 +913,7 @@ class ComputationBuilder(object):
shape = Shape(self.GetShape(a).np_dtype, dims)
return _wrap_data_handle(
self._client.RngUniform(
- _unwrap_data_handle(a), _unwrap_data_handle(b),
- _unwrap_shape(shape)))
+ _unwrap_data_handle(a), _unwrap_data_handle(b), shape))
def While(self, cond, body, init):
"""Enqueues a While operation onto the computation.
@@ -886,10 +931,37 @@ class ComputationBuilder(object):
_unwrap_data_handle(init)))
def Dot(self, lhs, rhs):
- """Matrix multiplication between lhs and rhs."""
+ """Enqueues a dot operation onto the computation.
+
+ Args:
+ lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array.
+ rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array.
+
+ Returns: a ComputationDataHandle representing the Dot operation.
+ """
return _wrap_data_handle(
self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
+ def DotGeneral(self, lhs, rhs, dimension_numbers):
+ """Enqueues a general dot operation onto the computation.
+
+ Args:
+ lhs: ComputationDataHandle for the left-hand-side array.
+ rhs: ComputationDataHandle for the right-hand-side array.
+ dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested
+ tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
+ integers representing the dimensions to treat as contracting dimensions
+ and batch dimensions on each input operand.
+
+  Returns:
+    A ComputationDataHandle representing the DotGeneral operation.
+ """
+ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers):
+ dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
+ return _wrap_data_handle(
+ self._client.DotGeneral(
+ _unwrap_data_handle(lhs), _unwrap_data_handle(rhs),
+ dimension_numbers))
+
def Conv(self, lhs, rhs, window_strides, padding):
"""Enqueues a Conv operation onto the computation.
@@ -1026,3 +1098,13 @@ def GetPaddingConfigFromTriples(triples):
dimension.edge_padding_high = hi
dimension.interior_padding = interior
return padding_config
+
+
+def GetDotDimensionsFromLists(dimension_numbers):
+ (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
+ dot_dims_proto = xla_data_pb2.DotDimensionNumbers()
+ dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
+ dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
+ dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
+ dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
+ return dot_dims_proto
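A minimal Python sketch of how the nested-tuple form accepted by DotGeneral maps onto a DotDimensionNumbers proto via GetDotDimensionsFromLists, mirroring the batched-matmul case in testDotGeneral below:

    # Batched matmul: lhs has shape [10, 3, 4], rhs has shape [10, 4, 5].
    # Contract dim 2 of lhs with dim 1 of rhs; batch over dim 0 of both.
    dimension_numbers = (([2], [1]), ([0], [0]))
    dot_dims_proto = GetDotDimensionsFromLists(dimension_numbers)
    # dot_dims_proto now carries lhs_contracting_dimensions=[2],
    # rhs_contracting_dimensions=[1], lhs_batch_dimensions=[0],
    # rhs_batch_dimensions=[0], and can be passed to ComputationBuilder.DotGeneral.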
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 3b5bbfd786..421fba40e3 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -444,6 +444,30 @@ class SingleOpTest(LocalComputationTest):
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
+ def testDotGeneral(self):
+ c = self._NewComputation()
+ rng = np.random.RandomState(0)
+ lhs = NumpyArrayF32(rng.randn(10, 3, 4))
+ rhs = NumpyArrayF32(rng.randn(10, 4, 5))
+ dimension_numbers = (([2], [1]), ([0], [0]))
+ c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
+ self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+
+ def testDotGeneralWithDotDimensionNumbersProto(self):
+ c = self._NewComputation()
+ rng = np.random.RandomState(0)
+ lhs = NumpyArrayF32(rng.randn(10, 3, 4))
+ rhs = NumpyArrayF32(rng.randn(10, 4, 5))
+
+ dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers()
+ dimension_numbers.lhs_contracting_dimensions.append(2)
+ dimension_numbers.rhs_contracting_dimensions.append(1)
+ dimension_numbers.lhs_batch_dimensions.append(0)
+ dimension_numbers.rhs_batch_dimensions.append(0)
+
+ c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
+ self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+
def testConvF32Same(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 11c3c037c4..6e0f0a0572 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -217,6 +217,7 @@ cuda_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:random_seed",
],
+ tags = ["notsan"],
)
cuda_py_test(
diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake
index aedb793d2a..fd05fa6d47 100644
--- a/tensorflow/contrib/cmake/external/protobuf.cmake
+++ b/tensorflow/contrib/cmake/external/protobuf.cmake
@@ -16,7 +16,7 @@ include (ExternalProject)
set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src)
set(PROTOBUF_URL https://github.com/google/protobuf.git)
-set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9)
+set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a)
if(WIN32)
set(protobuf_STATIC_LIBRARIES
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 2108e42bce..29a593f6bc 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -169,6 +170,7 @@ class JITTest(test.TestCase):
self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)
+@test_util.with_c_api
class CompilationEnabledInGradientTest(test.TestCase):
def testCompilationInGradient(self):
@@ -188,7 +190,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
for cg in c_grad_ops:
self.assertTrue(cg.get_attr("_XlaCompile"))
for ncg in nc_grad_ops:
- with self.assertRaisesRegexp(ValueError, "No attr named"):
+ with self.assertRaisesRegexp(ValueError, "[Nn]o attr named"):
ncg.get_attr("_XlaCompile")
# d/dx (x ** 4) = 4 * (x ** 3)
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index 3f64475e47..dbc35097dd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -35,6 +36,20 @@ from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest
+def remove_variants(get_next_op):
+  """Remove variants from a nest structure, so sess.run will execute."""
+  # TODO(b/72408568): Remove this once session.run can get
+  # variant tensors.
+
+ def _remove_variant(x):
+ if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
+ return ()
+ else:
+ return x
+
+ return nest.map_structure(_remove_variant, get_next_op)
+
+
class DatasetSerializationTestBase(test.TestCase):
"""Base class for testing serializable datasets."""
@@ -235,6 +250,7 @@ class DatasetSerializationTestBase(test.TestCase):
saver = self._import_meta_graph()
init_op, get_next_op = self._get_iterator_ops_from_collection(
ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
self._initialize(init_op, sess)
@@ -297,6 +313,7 @@ class DatasetSerializationTestBase(test.TestCase):
with ops.Graph().as_default() as g:
_, get_next_op, saver = self._build_graph(
ds_fn2, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
@@ -357,6 +374,7 @@ class DatasetSerializationTestBase(test.TestCase):
with ops.Graph().as_default() as g:
get_next_op, saver = self._build_empty_graph(
ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
@@ -390,6 +408,7 @@ class DatasetSerializationTestBase(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
with self.test_session(graph=g) as sess:
self._initialize(init_op, sess)
for _ in range(break_point):
@@ -485,11 +504,13 @@ class DatasetSerializationTestBase(test.TestCase):
else:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
return init_op, get_next_op, saver
for i in range(len(break_points) + 1):
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = get_ops()
+ get_next_op = remove_variants(get_next_op)
with self.test_session(graph=g) as sess:
if ckpt_saved:
if init_before_restore:
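A minimal usage sketch of the new remove_variants helper, assuming get_next_op comes from an iterator built by this test base and sess is an active session:

    # Variant-dtype tensors are replaced with (), so session.run succeeds.
    get_next_op = remove_variants(get_next_op)
    results = sess.run(get_next_op)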
diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py
index 2a7be95811..ed7dbc8904 100644
--- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py
+++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py
@@ -95,8 +95,7 @@ class MNISTModel(tfe.Network):
x = self.max_pool2d(x)
x = tf.layers.flatten(x)
x = self.fc1(x)
- if training:
- x = self.dropout(x)
+ x = self.dropout(x, training=training)
x = self.fc2(x)
return x
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index caa9dd8323..c9153c9352 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -457,6 +457,13 @@ def _get_local_devices(device_type):
def _split_batch(features, labels, number_of_shards, device):
"""Split input features and labes into batches."""
+ def ensure_divisible_by_shards(sequence):
+ batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0]
+ if batch_size % number_of_shards != 0:
+ raise ValueError(
+ 'Batch size {} needs to be divisible by the number of GPUs, which '
+ 'is {}.'.format(batch_size, number_of_shards))
+
def split_dictionary(dictionary):
"""Split a dictionary into shards."""
shards = [{} for _ in range(number_of_shards)]
@@ -467,6 +474,7 @@ def _split_batch(features, labels, number_of_shards, device):
sp_input=tensor, num_split=number_of_shards, axis=0)):
shards[i][name] = shard
else:
+ ensure_divisible_by_shards(tensor)
for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
shards[i][name] = shard
return shards
@@ -476,6 +484,7 @@ def _split_batch(features, labels, number_of_shards, device):
if isinstance(features, dict):
feature_shards = split_dictionary(features)
else:
+ ensure_divisible_by_shards(features)
feature_shards = array_ops.split(features, number_of_shards)
if labels is None:
@@ -483,6 +492,7 @@ def _split_batch(features, labels, number_of_shards, device):
elif isinstance(labels, dict):
label_shards = split_dictionary(labels)
else:
+ ensure_divisible_by_shards(labels)
label_shards = array_ops.split(labels, number_of_shards)
return feature_shards, label_shards
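A minimal sketch of the behavior the new ensure_divisible_by_shards check enforces, assuming two GPUs and dense inputs; model_fn and params are placeholders, and the same case is exercised in the test file below:

    replicated_model_fn = replicate_model_fn.replicate_model_fn(
        model_fn, devices=['/gpu:0', '/gpu:1'])   # number_of_shards == 2
    features = np.array([[1.0], [2.0], [3.0]])    # batch of 3
    labels = np.array([[1.0], [2.0], [3.0]])
    # Raises ValueError('Batch size 3 needs to be divisible by the number of
    # GPUs, which is 2.'). Sparse tensors are not subject to the check: they
    # are split with sparse_split and may split unevenly.
    replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN, params)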
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index 03d31226af..6936f8a131 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -433,6 +434,17 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
'probabilities': np.array([[0.1], [0.02]])
}, session.run(estimator_spec.predictions))
+ def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self):
+ features = np.array([[1.0], [2.0], [3.0]])
+ labels = np.array([[1.0], [2.0], [3.0]])
+
+ with self.assertRaisesRegexp(
+ ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'):
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ _ = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+
def test_unsupported_loss_reduction(self):
with self.assertRaisesRegexp(ValueError,
'.+none.+reduction.+is.+specified.+'):
@@ -981,8 +993,13 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
return list(map(evaluate_items, first_list)), list(
map(evaluate_items, second_list))
+ def assertSparseValuesEqual(self, a, b):
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
+
def test_simple_half_split(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.test_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -995,7 +1012,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
def test_to_each_their_own(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.test_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -1008,7 +1025,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
def test_one_batch(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.test_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -1021,7 +1038,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
def test_half_split_in_dictionary(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.test_session():
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
labels = [10.0, 11.0, 12.0, 13.0]
@@ -1035,6 +1052,60 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([10.0, 11.0], label_shards[0].eval())
self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
+ def test_sparse_tensor_can_be_split_unevenly(self):
+ with self.test_session():
+ features = {
+ 'x':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 2], [2, 2]],
+ values=[1.0, 2.0, 3.0],
+ dense_shape=[3, 4])
+ }
+ labels = np.array([[1.0], [2.0]])
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]),
+ feature_shards[0]['x'].eval())
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 2]], values=[3.], dense_shape=[1, 4]),
+ feature_shards[1]['x'].eval())
+ self.assertAllEqual([[1.0]], label_shards[0].eval())
+ self.assertAllEqual([[2.0]], label_shards[1].eval())
+
+ def test_sparse_tensor_can_be_split_unevenly_repeated_row(self):
+ with self.test_session():
+ features = {
+ 'x':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[1.0, 2.0, 3.0],
+ dense_shape=[3, 4])
+ }
+ labels = np.array([[1.0], [2.0]])
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ print(feature_shards[0]['x'].eval())
+ print(feature_shards[1]['x'].eval())
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[1., 2., 3.],
+ dense_shape=[2, 4]), feature_shards[0]['x'].eval())
+
+ second_batch = feature_shards[1]['x'].eval()
+ self.assertFalse(len(second_batch.indices))
+ self.assertFalse(len(second_batch.values))
+ self.assertAllEqual([1, 4], second_batch.dense_shape)
+ self.assertAllEqual([[1.0]], label_shards[0].eval())
+ self.assertAllEqual([[2.0]], label_shards[1].eval())
+
def test_one_batch_in_dictionary(self):
with self.test_session() as session: # pylint: disable=unused-variable
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 8d59fe66d9..63d0f1e1d4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -600,7 +600,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
input_fn=None,
batch_size=None,
outputs=None,
- as_iterable=True):
+ as_iterable=True,
+ iterate_batches=False):
"""Returns predictions for given features.
Args:
@@ -616,6 +617,9 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
for each example until inputs are exhausted. Note: The inputs must
terminate if you want the iterable to terminate (e.g. be sure to pass
num_epochs=1 if you are using something like read_batch_features).
+ iterate_batches: If True, yield the whole batch at once instead of
+ decomposing the batch into individual samples. Only relevant when
+ as_iterable is True.
Returns:
A numpy array of predicted classes or regression values if the
@@ -635,7 +639,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
input_fn=input_fn,
feed_fn=feed_fn,
outputs=outputs,
- as_iterable=as_iterable)
+ as_iterable=as_iterable,
+ iterate_batches=iterate_batches)
def get_variable_value(self, name):
"""Returns value of the variable given by name.
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 69a597dc5a..a8db149eaa 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -36,6 +36,10 @@ constexpr const int kSlotsToReserve = 128;
namespace tflite {
// A trivial implementation of GraphInfo around the Interpreter.
+// NOTE: this interpreter info represents the subset of the
+// graph that is executed according to the execution plan. Thus,
+// the indices are execution plan indices rather than raw node
+// indices.
class InterpreterInfo : public GraphInfo {
public:
explicit InterpreterInfo(Interpreter* interpreter)
@@ -45,9 +49,12 @@ class InterpreterInfo : public GraphInfo {
TfLiteTensor* tensor(size_t index) override {
return interpreter_->tensor(index);
}
- size_t num_nodes() const override { return interpreter_->nodes_size(); }
+ size_t num_nodes() const override {
+ return interpreter_->execution_plan().size();
+ }
const TfLiteNode& node(size_t index) const override {
- return interpreter_->node_and_registration(index)->first;
+ int node_index = interpreter_->execution_plan()[index];
+ return interpreter_->node_and_registration(node_index)->first;
}
const std::vector<int>& inputs() const override {
return interpreter_->inputs();
@@ -73,7 +80,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
// Reserve some space for the tensors to avoid excessive resizing.
tensors_.reserve(kSlotsToReserve);
nodes_and_registration_.reserve(kSlotsToReserve);
- next_node_to_prepare_ = 0;
+ next_execution_plan_index_to_prepare_ = 0;
UseNNAPI(false);
}
@@ -160,7 +167,7 @@ TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector<int>& x) {
} // namespace
TfLiteStatus Interpreter::AllocateTensors() {
- next_node_to_prepare_ = 0;
+ next_execution_plan_index_to_prepare_ = 0;
if (memory_planner_) {
TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
}
@@ -190,7 +197,8 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
&context_,
CheckTensorIndices("node outputs", outputs.data(), outputs.size()));
- if (node_index) *node_index = nodes_and_registration_.size();
+ int new_node_index = nodes_and_registration_.size();
+ if (node_index) *node_index = new_node_index;
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
auto& node_and_reg = nodes_and_registration_.back();
TfLiteNode& node = node_and_reg.first;
@@ -213,6 +221,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
}
node.builtin_data = builtin_data_deleter.release();
node_and_reg.second = *registration;
+ execution_plan_.push_back(new_node_index);
return kTfLiteOk;
}
@@ -240,16 +249,19 @@ bool HasDynamicTensor(const TfLiteContext& context,
return false;
}
-TfLiteStatus Interpreter::PrepareOpsStartingAt(int first_node,
- int* last_node_prepared) {
- for (int i = first_node; i < nodes_and_registration_.size(); i++) {
- TfLiteNode& node = nodes_and_registration_[i].first;
- const TfLiteRegistration& registration = nodes_and_registration_[i].second;
+TfLiteStatus Interpreter::PrepareOpsStartingAt(
+ int first_execution_plan_index, int* last_execution_plan_index_prepared) {
+ for (int execution_plan_index = first_execution_plan_index;
+ execution_plan_index < execution_plan_.size(); execution_plan_index++) {
+ int node_index = execution_plan_[execution_plan_index];
+ TfLiteNode& node = nodes_and_registration_[node_index].first;
+ const TfLiteRegistration& registration =
+ nodes_and_registration_[node_index].second;
if (OpPrepare(registration, &node) == kTfLiteError) {
return kTfLiteError;
}
- *last_node_prepared = i;
+ *last_execution_plan_index_prepared = execution_plan_index;
// Discontinue if the node has dynamic outputs. Note that we don't
// stop for dynamic temporary tensors since they won't affect the
@@ -268,14 +280,14 @@ TfLiteStatus Interpreter::PrepareOpsAndTensors() {
memory_planner_->PlanAllocations();
}
- int last_node_prepared = 0;
+ int last_exec_plan_index_prepared = 0;
- TF_LITE_ENSURE_STATUS(
- PrepareOpsStartingAt(next_node_to_prepare_, &last_node_prepared));
+ TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt(
+ next_execution_plan_index_to_prepare_, &last_exec_plan_index_prepared));
TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations(
- next_node_to_prepare_, last_node_prepared));
+ next_execution_plan_index_to_prepare_, last_exec_plan_index_prepared));
- next_node_to_prepare_ = last_node_prepared + 1;
+ next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1;
return kTfLiteOk;
}
@@ -291,7 +303,8 @@ TfLiteStatus Interpreter::Invoke() {
TfLiteStatus status = kTfLiteOk;
if (nnapi_delegate_) {
- if (next_node_to_prepare_ == nodes_and_registration_.size()) {
+ TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
+ if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) {
TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
return kTfLiteOk;
} else {
@@ -311,13 +324,17 @@ TfLiteStatus Interpreter::Invoke() {
// TODO(b/71913981): we should force recalculation in the presence of dynamic
// tensors, because they may have new value which in turn may affect shapes
// and allocations.
- for (int i = 0; i < nodes_and_registration_.size(); i++) {
- if (i == next_node_to_prepare_) {
+ for (int execution_plan_index = 0;
+ execution_plan_index < execution_plan_.size(); execution_plan_index++) {
+ if (execution_plan_index == next_execution_plan_index_to_prepare_) {
TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
- TF_LITE_ENSURE(&context_, next_node_to_prepare_ >= i);
+ TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >=
+ execution_plan_index);
}
- TfLiteNode& node = nodes_and_registration_[i].first;
- const TfLiteRegistration& registration = nodes_and_registration_[i].second;
+ int node_index = execution_plan_[execution_plan_index];
+ TfLiteNode& node = nodes_and_registration_[node_index].first;
+ const TfLiteRegistration& registration =
+ nodes_and_registration_[node_index].second;
if (OpInvoke(registration, &node) == kTfLiteError) {
status = kTfLiteError;
}
@@ -421,6 +438,14 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
return kTfLiteOk;
}
+TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
+ for (int node_index : new_plan) {
+ TF_LITE_ENSURE(&context_, node_index >= 0 && node_index < nodes_size());
+ }
+ execution_plan_ = new_plan;
+ return kTfLiteOk;
+}
+
TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
TfLiteIntArray* new_size) {
// Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 52e52df1b6..c822557d02 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -166,6 +166,13 @@ class Interpreter {
// Return the number of ops in the model.
int nodes_size() const { return nodes_and_registration_.size(); }
+ // WARNING: Experimental interface, subject to change
+ const std::vector<int>& execution_plan() const { return execution_plan_; }
+
+ // WARNING: Experimental interface, subject to change
+  // Overrides the execution plan. The new plan's indices are bounds-checked.
+ TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
+
// Get a tensor data structure.
// TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
// read/write access to structure
@@ -279,7 +286,8 @@ class Interpreter {
// dynamic tensors is found or all ops have been prepared. Fill
   // 'last_execution_plan_index_prepared' with the execution plan index of the
   // op containing dynamic tensors, or the index of the last op in the plan.
- TfLiteStatus PrepareOpsStartingAt(int first_node, int* last_node_prepared);
+ TfLiteStatus PrepareOpsStartingAt(int first_execution_plan_index,
+ int* last_execution_plan_index_prepared);
// Tensors needed by the interpreter. Use `AddTensors` to add more blank
// tensor entries. Note, `tensors_.data()` needs to be synchronized to the
@@ -354,7 +362,14 @@ class Interpreter {
   // node id, and execute the node to generate the output tensor before
   // continuing to allocate successors. This process repeats until all nodes
   // are executed.
// NOTE: this relies on the order of nodes that is in topological order.
- int next_node_to_prepare_;
+ int next_execution_plan_index_to_prepare_;
+
+ // WARNING: This is an experimental interface that is subject to change.
+ // This is a list of node indices (to index into nodes_and_registration).
+ // This represents a valid topological sort (dependency ordered) execution
+ // plan. In particular, it is valid for this ordering to contain only a
+ // subset of the node indices.
+ std::vector<int> execution_plan_;
// Whether to delegate to NN API
std::unique_ptr<NNAPIDelegate> nnapi_delegate_;
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index edff210943..2ab4bb6567 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -514,6 +514,133 @@ TEST(BasicInterpreter, TestCustomErrorReporter) {
ASSERT_EQ(reporter.calls, 1);
}
+// Test fixture that allows playing with execution plans. It creates a two
+// node graph that can be executed in either [0,1] order or [1,0] order.
+// The CopyOp records when it is invoked in the class member run_order_
+// so we can test whether the execution plan was honored.
+class TestExecutionPlan : public ::testing::Test {
+  // Encapsulates a node id and makes it available through a C-compatible
+  // primitive data type. Allocatable with placement new, but never destructed,
+  // so make sure this doesn't own any heap-allocated data. This is then used
+  // as op-local data to allow access to the test fixture data.
+ class CallReporting {
+ public:
+ CallReporting(int node_id, std::vector<int>* run_order)
+ : node_id_(node_id), run_order_(run_order) {}
+
+ void Record() { run_order_->push_back(node_id_); }
+
+ private:
+ // The node id for this particular node
+ int node_id_;
+ // A pointer to the global run-order
+ std::vector<int>* run_order_;
+ };
+
+ // Build a kernel registration for an op that copies its one input
+ // to an output
+ TfLiteRegistration CopyOpRegistration() {
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ // Set output size to input size
+ TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]];
+ TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
+ return context->ResizeTensor(context, tensor1, newSize);
+ };
+
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ CallReporting* call_reporting =
+ reinterpret_cast<CallReporting*>(node->builtin_data);
+ // Copy input data to output data.
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ int num = a0->dims->data[0];
+ for (int i = 0; i < num; i++) {
+ a1->data.f[i] = a0->data.f[i];
+ }
+ call_reporting->Record();
+ return kTfLiteOk;
+ };
+ return reg;
+ }
+
+ // Adds a copy node going from tensor `input` to output tensor `output`.
+  // Note: input is used as the node_id, and run_order is injected as
+  // op-accessible data. This is a slightly unusual way to do it, but it uses
+  // op functionality to avoid static global variables.
+ void MakeCopyNode(int input, int output) {
+    // Ownership of call_reporting is taken by the interpreter (malloc is used
+    // because nodes are a C99 interface, so free() is used to release it).
+ TfLiteRegistration copy_op = CopyOpRegistration();
+ CallReporting* call_reporting_1 =
+ reinterpret_cast<CallReporting*>(malloc(sizeof(CallReporting)));
+ new (call_reporting_1) CallReporting(input, &run_order_);
+ ASSERT_EQ(interpreter_.AddNodeWithParameters(
+ {0}, {2}, nullptr, 0,
+ reinterpret_cast<void*>(call_reporting_1), &copy_op),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk);
+ }
+
+ void SetUp() final {
+ // Add two inputs and two outputs that don't depend on each other
+ ASSERT_EQ(interpreter_.AddTensors(4), kTfLiteOk);
+ interpreter_.SetInputs({0, 1});
+ interpreter_.SetOutputs({2, 3});
+ TfLiteQuantizationParams quantized;
+ for (int tensor_index = 0; tensor_index < 4; tensor_index++) {
+ ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
+ tensor_index, kTfLiteFloat32, "", {3}, quantized),
+ kTfLiteOk);
+ }
+
+ // Define two copy functions that also use the user_data to report that
+ // they were called.
+ // i.e. tensor[2] = copy(tensor[0]); tensor[3] = copy(tensor[1]);
+    // thus we can reorder the two nodes arbitrarily and still satisfy dependency
+ // order.
+ MakeCopyNode(0, 2);
+ MakeCopyNode(1, 3);
+
+ ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
+ }
+
+ protected:
+ Interpreter interpreter_;
+
+ // list of node_ids that were run
+ std::vector<int> run_order_;
+};
+
+TEST_F(TestExecutionPlan, DefaultExecutionPlan) {
+ // Check default order
+ ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
+ ASSERT_EQ(run_order_, std::vector<int>({0, 1}));
+}
+
+TEST_F(TestExecutionPlan, ReversedExecutionPlan) {
+ // Check reversed order
+ interpreter_.SetExecutionPlan({1, 0});
+ ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
+ ASSERT_EQ(run_order_, std::vector<int>({1, 0}));
+}
+
+TEST_F(TestExecutionPlan, SubsetExecutionPlan) {
+ // Check running only node index 1
+ interpreter_.SetExecutionPlan({1});
+ ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
+ ASSERT_EQ(run_order_, std::vector<int>({1}));
+}
+
+TEST_F(TestExecutionPlan, NullExecutionPlan) {
+ // Check nothing executed.
+ interpreter_.SetExecutionPlan({});
+ ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk);
+ ASSERT_EQ(run_order_, std::vector<int>());
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index d9051f3516..8c40adfae5 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -249,6 +249,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 37f499a4d0..a5095e1e64 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -42,7 +42,7 @@ namespace conv {
enum KernelType {
kReference,
kGenericOptimized, // Neon-free
- kNeonOptimized,
+ kMultithreadOptimized,
};
struct OpData {
@@ -290,26 +290,33 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
- if (kernel_type == kReference) {
- reference_ops::Conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
- output_offset, data->output_multiplier, data->output_shift,
- data->output_activation_min, data->output_activation_max,
- GetTensorData<uint8_t>(output), GetTensorDims(output),
- GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
- } else {
- optimized_ops::Conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
- output_offset, data->output_multiplier, data->output_shift,
- data->output_activation_min, data->output_activation_max,
- GetTensorData<uint8_t>(output), GetTensorDims(output),
- GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ switch (kernel_type) {
+ case kReference:
+ reference_ops::Conv(
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
+ GetTensorData<int32_t>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height, data->padding.width,
+ data->padding.height, output_offset, data->output_multiplier,
+ data->output_shift, data->output_activation_min,
+ data->output_activation_max, GetTensorData<uint8_t>(output),
+ GetTensorDims(output), GetTensorData<uint8_t>(im2col),
+ GetTensorDims(im2col), gemm_context);
+ break;
+ case kGenericOptimized:
+ case kMultithreadOptimized:
+ // There is only one optimized implementation for Quantized Conv.
+ optimized_ops::Conv(
+ GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
+ GetTensorData<int32_t>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height, data->padding.width,
+ data->padding.height, output_offset, data->output_multiplier,
+ data->output_shift, data->output_activation_min,
+ data->output_activation_max, GetTensorData<uint8_t>(output),
+ GetTensorDims(output), GetTensorData<uint8_t>(im2col),
+ GetTensorDims(im2col), gemm_context);
+ break;
}
}
@@ -322,31 +329,46 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
- if (kernel_type == kReference) {
- reference_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- data->padding.width, data->padding.height,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
- } else {
- const float* filter_data;
- if (data->need_hwcn_weights) {
- filter_data = GetTensorData<float>(hwcn_weights);
- } else {
- filter_data = GetTensorData<float>(filter);
+ switch (kernel_type) {
+ case kReference: {
+ reference_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(filter), GetTensorDims(filter),
+ GetTensorData<float>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height,
+ data->padding.width, data->padding.height,
+ output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
+ break;
+ }
+ case kGenericOptimized: {
+ optimized_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(filter), GetTensorDims(filter),
+ GetTensorData<float>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height,
+ data->padding.width, data->padding.height,
+ output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
+ break;
+ }
+ case kMultithreadOptimized: {
+ const float* filter_data;
+ if (data->need_hwcn_weights) {
+ filter_data = GetTensorData<float>(hwcn_weights);
+ } else {
+ filter_data = GetTensorData<float>(filter);
+ }
+ multithreaded_ops::Conv(
+ GetTensorData<float>(input), GetTensorDims(input), filter_data,
+ GetTensorDims(filter), GetTensorData<float>(bias),
+ GetTensorDims(bias), params->stride_width, params->stride_height,
+ data->padding.width, data->padding.height, params->padding,
+ output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
+ break;
}
-
- multithreaded_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input), filter_data,
- GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, data->padding.width,
- data->padding.height, params->padding, output_activation_min,
- output_activation_max, GetTensorData<float>(output),
- GetTensorDims(output), GetTensorData<float>(im2col),
- GetTensorDims(im2col));
}
}
@@ -407,18 +429,14 @@ TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() {
return &r;
}
-TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() {
+TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT() {
static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare,
- conv::Eval<conv::kNeonOptimized>};
+ conv::Eval<conv::kMultithreadOptimized>};
return &r;
}
TfLiteRegistration* Register_CONV_2D() {
-#ifdef USE_NEON
- return Register_CONVOLUTION_NEON_OPT();
-#else
- return Register_CONVOLUTION_GENERIC_OPT();
-#endif
+ return Register_CONVOLUTION_MULTITHREADED_OPT();
}
} // namespace builtin
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 1d0a81c313..7550f7cc0d 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -15,12 +15,24 @@ limitations under the License.
#include <cstdarg>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_CONVOLUTION_REF();
+TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT();
+TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT();
+
+} // namespace builtin
+} // namespace ops
+
namespace {
using ::testing::ElementsAreArray;
@@ -30,9 +42,9 @@ class BaseConvolutionOpModel : public SingleOpModel {
// TODO(ahentz): Also test different activation types, bias, padding types,
// stride values.
BaseConvolutionOpModel(
- const TensorData& input, const TensorData& filter,
- const TensorData& output, int stride_width = 2, int stride_height = 2,
- enum Padding padding = Padding_VALID,
+ TfLiteRegistration* registration, const TensorData& input,
+ const TensorData& filter, const TensorData& output, int stride_width = 2,
+ int stride_height = 2, enum Padding padding = Padding_VALID,
enum ActivationFunctionType activation = ActivationFunctionType_NONE) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -62,6 +74,8 @@ class BaseConvolutionOpModel : public SingleOpModel {
stride_height, activation)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(BuiltinOperator_CONV_2D,
+ registration);
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
@@ -83,12 +97,25 @@ class ConvolutionOpModel : public BaseConvolutionOpModel {
void SetInput(std::initializer_list<float> data) {
PopulateTensor(input_, data);
}
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
-TEST(ConvolutionOpTest, SimpleTestFloat32) {
- ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+ {"Reference", ops::builtin::Register_CONVOLUTION_REF()},
+ {"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()},
+ {"MultithreadedOptimized",
+ ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()},
+});
+
+class ConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
+TEST_P(ConvolutionOpTest, SimpleTestFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
{TensorType_FLOAT32, {3, 2, 2, 1}},
{TensorType_FLOAT32, {}});
@@ -117,8 +144,8 @@ TEST(ConvolutionOpTest, SimpleTestFloat32) {
}));
}
-TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
- ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}},
+TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {}},
/*stride_width=*/3, /*stride_height=*/1);
@@ -139,7 +166,7 @@ TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
}));
}
-TEST(ConvolutionOpTest, HandCalculatedFloat32) {
+TEST_P(ConvolutionOpTest, HandCalculatedFloat32) {
const int depth = 1;
const int image_width = 4;
const int image_height = 3;
@@ -150,6 +177,7 @@ TEST(ConvolutionOpTest, HandCalculatedFloat32) {
const int stride_height = 1;
const Padding padding = Padding_SAME;
ConvolutionOpModel m(
+ GetRegistration(),
{TensorType_FLOAT32,
{image_batch_count, image_height, image_width, depth}},
{TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
@@ -192,7 +220,7 @@ TEST(ConvolutionOpTest, HandCalculatedFloat32) {
178, 187, 234, 261, 121}));
}
-TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {
+TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {
const int depth = 1;
const int image_width = 4;
const int image_height = 3;
@@ -203,6 +231,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {
const int stride_height = 1;
const Padding padding = Padding_SAME;
ConvolutionOpModel m(
+ GetRegistration(),
{TensorType_FLOAT32,
{image_batch_count, image_height, image_width, depth}},
{TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
@@ -245,7 +274,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {
367, 188, 197, 244, 271, 131}));
}
-TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) {
+TEST_P(ConvolutionOpTest, HandCalculatedWithReluFloat32) {
const int depth = 1;
const int image_width = 4;
const int image_height = 3;
@@ -256,6 +285,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) {
const int stride_height = 1;
const Padding padding = Padding_SAME;
ConvolutionOpModel m(
+ GetRegistration(),
{TensorType_FLOAT32,
{image_batch_count, image_height, image_width, depth}},
{TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
@@ -300,7 +330,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) {
ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0}));
}
-TEST(ConvolutionOpTest, HandCalculatedValidFloat32) {
+TEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) {
const int depth = 1;
const int image_width = 4;
const int image_height = 3;
@@ -311,6 +341,7 @@ TEST(ConvolutionOpTest, HandCalculatedValidFloat32) {
const int stride_height = 1;
const Padding padding = Padding_VALID;
ConvolutionOpModel m(
+ GetRegistration(),
{TensorType_FLOAT32,
{image_batch_count, image_height, image_width, depth}},
{TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
@@ -366,8 +397,9 @@ class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
// In this tests we set the input and output scales so that the results
// match exactly the 'non-quantized' version.
-TEST(ConvolutionOpTest, SimpleTestQuantized) {
- QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
+TEST_P(ConvolutionOpTest, SimpleTestQuantized) {
+ QuantizedConvolutionOpModel m(GetRegistration(),
+ {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
{TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
{TensorType_UINT8, {}, -127, 128});
m.SetInput({
@@ -405,8 +437,9 @@ TEST(ConvolutionOpTest, SimpleTestQuantized) {
}));
}
-TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
- QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64},
+TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
+ QuantizedConvolutionOpModel m(GetRegistration(),
+ {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64},
{TensorType_UINT8, {}, -127, 128},
/*stride_width=*/3, /*stride_height=*/1);
@@ -430,6 +463,11 @@ TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
167, 93, //
}));
}
+
+INSTANTIATE_TEST_CASE_P(
+ ConvolutionOpTest, ConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 3a58e7ec32..6f56aa6bf3 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -172,11 +172,14 @@ void SingleOpModel::BuildInterpreter(
auto* model = GetModel(builder_.GetBufferPointer());
- ops::builtin::BuiltinOpResolver builtins;
- for (const auto& reg : custom_registrations_) {
- builtins.AddCustom(reg.first.data(), reg.second());
+ if (!resolver_) {
+ auto resolver = new ops::builtin::BuiltinOpResolver();
+ for (const auto& reg : custom_registrations_) {
+ resolver->AddCustom(reg.first.data(), reg.second());
+ }
+ resolver_ = std::unique_ptr<OpResolver>(resolver);
}
- InterpreterBuilder(model, builtins)(&interpreter_);
+ InterpreterBuilder(model, *resolver_)(&interpreter_);
CHECK(interpreter_ != nullptr);
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index cc445299ff..7d476ba1ea 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -85,6 +85,23 @@ struct TensorData {
int32_t zero_point;
};
+class SingleOpResolver : public OpResolver {
+ public:
+ SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
+ : op_(op), registration_(registration) {}
+ TfLiteRegistration* FindOp(BuiltinOperator op) const override {
+ if (op == op_) {
+ return registration_;
+ }
+ return nullptr;
+ }
+ TfLiteRegistration* FindOp(const char* op) const override { return nullptr; }
+
+ private:
+ const BuiltinOperator op_;
+ TfLiteRegistration* registration_;
+};
+
class SingleOpModel {
public:
SingleOpModel() {}
@@ -178,11 +195,16 @@ class SingleOpModel {
return result;
}
+ void SetResolver(std::unique_ptr<OpResolver> resolver) {
+ resolver_ = std::move(resolver);
+ }
+
protected:
int32_t GetTensorSize(int index) const;
flatbuffers::FlatBufferBuilder builder_;
std::unique_ptr<tflite::Interpreter> interpreter_;
+ std::unique_ptr<OpResolver> resolver_;
private:
int AddTensor(TensorData t, std::initializer_list<int> data);
@@ -197,6 +219,36 @@ class SingleOpModel {
std::map<string, std::function<TfLiteRegistration*()>> custom_registrations_;
};
+// Base class for single op unit tests.
+// The tests are parameterized to test multiple kernels for a single op.
+// The parameters are strings like "optimized" and "reference" to have better
+// readability in test reports.
+//
+// To use this class:
+// * Define a constant map from strings to TfLiteRegistration.
+// * Implement a test class that inherits SingleOpTest.
+// * Instantiate the test cases with SingleOpTest::GetKernelTags helper
+// function.
+// * Call GetRegistration to get the TfLiteRegistration to be used before
+// building the interpreter.
+class SingleOpTest : public ::testing::TestWithParam<string> {
+ public:
+ static std::vector<string> GetKernelTags(
+ const std::map<string, TfLiteRegistration*>& kernel_map) {
+ std::vector<string> tags;
+ for (auto it : kernel_map) {
+ tags.push_back(it.first);
+ }
+ return tags;
+ }
+
+ protected:
+ virtual const std::map<string, TfLiteRegistration*>& GetKernelMap() = 0;
+ TfLiteRegistration* GetRegistration() {
+ return GetKernelMap().at(GetParam());
+ }
+};
+
// Strings have a special implementation that is in test_util.cc
template <>
std::vector<string> SingleOpModel::ExtractVector(int index);
diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
new file mode 100755
index 0000000000..b58ae26601
--- /dev/null
+++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
@@ -0,0 +1,81 @@
+#!/bin/bash -x
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+echo "Starting"
+TFLITE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/.."
+
+TMP_DIR=$(mktemp -d)
+echo "Package dir: " $TMP_DIR
+FW_DIR=$TMP_DIR/tensorflow_lite_ios_frameworks
+FW_DIR_TFLITE=$FW_DIR/tensorflow_lite.framework
+FW_DIR_TFLITE_HDRS=$FW_DIR_TFLITE/Headers
+
+echo "Creating target Headers directories"
+mkdir -p $FW_DIR_TFLITE_HDRS
+
+echo "Headers, populating: TensorFlow Lite"
+cd $TFLITE_DIR/../../..
+
+find tensorflow/contrib/lite -name '*.h' \
+ -not -path 'tensorflow/contrib/lite/downloads/*' \
+ -not -path 'tensorflow/contrib/lite/examples/*' \
+ -not -path 'tensorflow/contrib/lite/gen/*' \
+ -not -path 'tensorflow/contrib/lite/toco/*' \
+ -not -path 'tensorflow/contrib/lite/nnapi/*' \
+ -not -path 'tensorflow/contrib/lite/java/*' \
+ | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T -
+cd $FW_DIR_TFLITE_HDRS
+tar xf tmp.tar
+rm -f tmp.tar
+
+echo "Headers, populating: Flatbuffer"
+cd $TFLITE_DIR/downloads/flatbuffers/include/
+find . -name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T -
+cd $FW_DIR_TFLITE_HDRS
+tar xf tmp.tar
+rm -f tmp.tar
+
+cd $TFLITE_DIR/../../..
+echo "Generate master LICENSE file and copy to target"
+bazel build //tensorflow/tools/lib_package:clicenses_generate
+cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE \
+ $FW_DIR_TFLITE
+
+echo "Copying static libraries"
+cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \
+ $FW_DIR_TFLITE/tensorflow_lite
+
+# This is required, otherwise they interfere with the documentation of the
+# pod at cocoapods.org.
+echo "Remove all README files"
+cd $FW_DIR_TFLITE_HDRS
+find . -type f -name README\* -exec rm -f {} \;
+find . -type f -name readme\* -exec rm -f {} \;
+
+TARGET_GEN_LOCATION="$TFLITE_DIR/gen/ios_frameworks"
+echo "Moving results to target: " $TARGET_GEN_LOCATION
+cd $FW_DIR
+zip -q -r tensorflow_lite.framework.zip tensorflow_lite.framework -x .DS_Store
+rm -rf $TARGET_GEN_LOCATION
+mkdir -p $TARGET_GEN_LOCATION
+cp -r tensorflow_lite.framework.zip $TARGET_GEN_LOCATION
+
+echo "Cleaning up"
+rm -rf $TMP_DIR
+
+echo "Finished"
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 529df3cd2e..4c70b01a9d 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -621,7 +621,8 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
GraphDef* tensorflow_graph) {
string softmax_input;
Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
- if (providing_op->type == OperatorType::kTensorFlowReshape) {
+ if (providing_op != nullptr &&
+ providing_op->type == OperatorType::kTensorFlowReshape) {
softmax_input = src_op.inputs[0];
} else {
// Insert a reshape operator that reduces the dimensions down to the 2 that
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index 1bffcfb987..4d3b553b22 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -99,8 +99,11 @@ cc_library(
srcs = ["verifier.cc"],
hdrs = ["verifier.h"],
deps = [
+ "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/schema:schema_fbs",
+ "@com_google_absl//absl/base:core_headers",
],
)
@@ -112,6 +115,7 @@ cc_test(
":verifier",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc
index 95a0895379..726e2aaa31 100644
--- a/tensorflow/contrib/lite/tools/verifier.cc
+++ b/tensorflow/contrib/lite/tools/verifier.cc
@@ -14,13 +14,32 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/tools/verifier.h"
+#include <climits>
#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
namespace {
+// Reports error message when the reporter is set.
+void ReportError(ErrorReporter* error_reporter, const char* format, ...) {
+ if (error_reporter) {
+ va_list args;
+ va_start(args, format);
+ error_reporter->Report(format, args);
+ va_end(args);
+ }
+}
+
+// Returns a pointer to the uint32_t value at ptr.
+const uint32_t* GetIntPtr(const char* ptr) {
+ return reinterpret_cast<const uint32_t*>(ptr);
+}
+
+// Verifies flatbuffer format of the model contents and returns the in-memory
+// model.
const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) {
::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
if (VerifyModelBuffer(verifier)) {
@@ -30,14 +49,159 @@ const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) {
}
}
+const uint32_t kMaxNumString = UINT_MAX / sizeof(int32_t) - 2;
+
+// Verifies string tensor has legit buffer contents that follow the schema
+// defined in lite/string_util.h
+bool VerifyStringTensorBuffer(const Buffer& buffer,
+ ErrorReporter* error_reporter) {
+ uint32_t buffer_size = buffer.data()->size();
+ const char* buffer_ptr = reinterpret_cast<const char*>(buffer.data()->data());
+
+ uint32_t num_strings = *GetIntPtr(buffer_ptr);
+ if (num_strings > kMaxNumString) {
+ ReportError(error_reporter,
+ "String tensor has invalid num of string set: %d", num_strings);
+ return false;
+ }
+ uint32_t header_offsets =
+ static_cast<uint32_t>(num_strings + 2) * sizeof(int32_t);
+
+ if (buffer_size < header_offsets) {
+ ReportError(error_reporter,
+ "String tensor buffer requires at least %d bytes, but is "
+ "allocated with %d bytes",
+ header_offsets, buffer_size);
+ return false;
+ }
+
+ uint32_t prev_ptr = header_offsets;
+ uint32_t offset = sizeof(int32_t);
+
+ if (*GetIntPtr(buffer_ptr + offset) != header_offsets) {
+ ReportError(error_reporter,
+ "String tensor buffer initial offset must be: %d",
+ header_offsets);
+ return false;
+ }
+ offset += sizeof(int32_t);
+ for (int i = 1; i <= num_strings; i++, offset += sizeof(int32_t)) {
+ int string_offset = *GetIntPtr(buffer_ptr + offset);
+ if (string_offset < prev_ptr || string_offset > buffer_size) {
+ ReportError(error_reporter, "String tensor buffer is invalid: index %d",
+ i);
+ return false;
+ }
+ }
+ if (*GetIntPtr(buffer_ptr + offset - sizeof(int32_t)) != buffer_size) {
+ ReportError(error_reporter, "String tensor buffer last offset must be %d",
+ buffer_size);
+ return false;
+ }
+ return true;
+}
+
+// Verifies numeric tensor has legit buffer.
+bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
+ ErrorReporter* error_reporter) {
+ uint64_t bytes_required = 1;
+ for (int dim : *tensor.shape()) {
+ bytes_required *= dim;
+ if (bytes_required > UINT_MAX) {
+ ReportError(error_reporter, "Tensor dimension overflow");
+ return false;
+ }
+ }
+ switch (tensor.type()) {
+ case TensorType_FLOAT32:
+ bytes_required *= sizeof(float);
+ break;
+ case TensorType_INT32:
+ bytes_required *= sizeof(int32_t);
+ break;
+ case TensorType_UINT8:
+ bytes_required *= sizeof(uint8_t);
+ break;
+ case TensorType_INT64:
+ bytes_required *= sizeof(int64_t);
+ break;
+ case TensorType_FLOAT16:
+ // FALLTHROUGH_INTENDED;
+ default:
+ ReportError(error_reporter, "Invalid tensor type: %d", tensor.type());
+ return false;
+ }
+ if (bytes_required > UINT_MAX) {
+ ReportError(error_reporter, "Tensor dimension overflow");
+ return false;
+ }
+
+ if (bytes_required != buffer.data()->size()) {
+ ReportError(
+ error_reporter,
+ "Tensor requires %d bytes, but is allocated with %d bytes buffer",
+ bytes_required, buffer.data()->size());
+ return false;
+ }
+  // TODO(yichengfan): verify quantized tensors.
+  return true;
+}
+
+// Verifies tensors have valid properties and legit buffer if set.
+bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) {
+ if (!model.subgraphs()) {
+ return true;
+ }
+ for (const auto& subgraph : *model.subgraphs()) {
+    if (!subgraph->tensors()) {
+      continue;
+    }
+ for (const auto& tensor : *subgraph->tensors()) {
+      if (!tensor->buffer()) {
+        continue;
+      }
+ if (tensor->buffer() >= model.buffers()->size()) {
+ ReportError(error_reporter, "Invalid tensor buffer index: %d",
+ tensor->buffer());
+ return false;
+ }
+ auto* buffer = model.buffers()->Get(tensor->buffer());
+ if (!buffer || !buffer->data()) {
+ ReportError(error_reporter, "Tensor buffer %d not set",
+ tensor->buffer());
+ return false;
+ }
+
+ if (tensor->type() == TensorType_STRING) {
+ if (!VerifyStringTensorBuffer(*buffer, error_reporter)) {
+ return false;
+ }
+ } else {
+ if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) {
+ return false;
+ }
+ }
+ }
+ }
+ return true;
+}
+
} // namespace
-bool Verify(const void* buf, size_t len) {
+bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter) {
const Model* model = VerifyFlatbufferAndGetModel(buf, len);
if (model == nullptr) {
+ ReportError(error_reporter, "Invalid flatbuffer format");
return false;
}
-
- return model->version() == TFLITE_SCHEMA_VERSION;
+ if (model->version() != TFLITE_SCHEMA_VERSION) {
+ ReportError(error_reporter, "Invalid model version %d", model->version());
+ return false;
+ }
+ if (!VerifyTensors(*model, error_reporter)) {
+ return false;
+ }
+ return true;
}
} // namespace tflite
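For reference, VerifyStringTensorBuffer expects the packed layout from lite/string_util.h: a little-endian int32 string count, then num_strings + 1 int32 offsets (the first equal to the header size (num_strings + 2) * 4, the last equal to the total buffer size), then the concatenated string bytes; the numeric-tensor check is simply prod(shape) * sizeof(dtype) == buffer size. A minimal Python 3 sketch, independent of the TF Lite code and with illustrative names only, reproduces the string payload used by the new TestSimpleModel test:

# Illustrative sketch of the string tensor buffer layout checked above.
import struct

def pack_string_tensor(strings):
  # Header: int32 count, then len(strings) + 1 int32 offsets.
  header_size = (len(strings) + 2) * 4
  offsets, pos = [], header_size
  for s in strings:
    offsets.append(pos)
    pos += len(s)
  offsets.append(pos)  # Final offset equals the total buffer size.
  header = struct.pack('<%di' % (len(strings) + 2), len(strings), *offsets)
  return header + b''.join(strings)

buf = pack_string_tensor([b'A', b'BC'])
assert buf == bytes([2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 19, 0, 0, 0]) + b'ABC'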
diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h
index 03e1f22b7e..d2bf3c91d5 100644
--- a/tensorflow/contrib/lite/tools/verifier.h
+++ b/tensorflow/contrib/lite/tools/verifier.h
@@ -18,13 +18,15 @@ limitations under the License.
#include <stdio.h>
+#include "tensorflow/contrib/lite/error_reporter.h"
+
namespace tflite {
// Verifies the integrity of a Tensorflow Lite flatbuffer model file.
// Currently, it verifies:
// * The file is following a legit flatbuffer schema.
// * The model is in supported version.
-bool Verify(const void* buf, size_t len);
+bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter);
} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc
index 0481a55a78..244d4f0396 100644
--- a/tensorflow/contrib/lite/tools/verifier_test.cc
+++ b/tensorflow/contrib/lite/tools/verifier_test.cc
@@ -28,31 +28,62 @@ using flatbuffers::FlatBufferBuilder;
using flatbuffers::Offset;
using flatbuffers::Vector;
-// Class that abstracts the list of buffers at the end of the TF Lite structure
-class DeferredBufferWriter {
+// Builds a single-subgraph model for use in tests.
+class TfLiteFlatbufferModelBuilder {
public:
- DeferredBufferWriter() {
- data_.push_back({}); // sentinel empty buffer.
+ TfLiteFlatbufferModelBuilder() {
+ buffers_.push_back(
+ CreateBuffer(builder_, builder_.CreateVector(std::vector<uint8_t>{})));
}
- Offset<Vector<Offset<Buffer>>> BuildBuffers(FlatBufferBuilder *builder) {
- std::vector<Offset<Buffer>> buffer_vector;
- for (const auto &vec : data_) {
- auto data_buffer = builder->CreateVector(vec.data(), vec.size());
- buffer_vector.push_back(tflite::CreateBuffer(*builder, data_buffer));
+ void AddTensor(const std::vector<int>& shape, tflite::TensorType type,
+ const std::vector<uint8_t>& buffer, const char* name) {
+ int buffer_index = 0;
+ if (!buffer.empty()) {
+ buffer_index = buffers_.size();
+ buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector(buffer)));
}
- return builder->CreateVector(buffer_vector);
+ tensors_.push_back(CreateTensorDirect(builder_, &shape, type, buffer_index,
+ name, /*quantization=*/0));
}
- // Registers a buffer index and takes ownership of the data to write to it.
- int Record(std::vector<uint8_t> data) {
- int buffer_index = data_.size();
- data_.emplace_back(std::move(data));
- return buffer_index;
+ void AddOperator(const std::vector<int32_t>& inputs,
+ const std::vector<int32_t>& outputs,
+ tflite::BuiltinOperator builtin_op, const char* custom_op) {
+ operator_codes_.push_back(
+ CreateOperatorCodeDirect(builder_, builtin_op, custom_op));
+ operators_.push_back(CreateOperator(
+ builder_, operator_codes_.size() - 1, builder_.CreateVector(inputs),
+ builder_.CreateVector(outputs), BuiltinOptions_NONE,
+ /*builtin_options=*/0,
+ /*custom_options=*/0, tflite::CustomOptionsFormat_FLEXBUFFERS));
+ }
+
+ void FinishModel(const std::vector<int32_t>& inputs,
+ const std::vector<int32_t>& outputs) {
+ auto subgraph = std::vector<Offset<SubGraph>>({CreateSubGraph(
+ builder_, builder_.CreateVector(tensors_),
+ builder_.CreateVector(inputs), builder_.CreateVector(outputs),
+ builder_.CreateVector(operators_),
+ builder_.CreateString("test_subgraph"))});
+ auto result = CreateModel(
+ builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(operator_codes_),
+ builder_.CreateVector(subgraph), builder_.CreateString("test_model"),
+ builder_.CreateVector(buffers_));
+ tflite::FinishModelBuffer(builder_, result);
+ }
+
+ bool Verify() {
+ return tflite::Verify(builder_.GetBufferPointer(), builder_.GetSize(),
+ DefaultErrorReporter());
}
private:
- std::vector<std::vector<unsigned char>> data_;
+ FlatBufferBuilder builder_;
+ std::vector<Offset<Operator>> operators_;
+ std::vector<Offset<OperatorCode>> operator_codes_;
+ std::vector<Offset<Tensor>> tensors_;
+ std::vector<Offset<Buffer>> buffers_;
};
TEST(VerifyModel, TestEmptyModel) {
@@ -62,43 +93,26 @@ TEST(VerifyModel, TestEmptyModel) {
/*description=*/0, /*buffers=*/0);
::tflite::FinishModelBuffer(builder, model);
- ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+ ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(),
+ DefaultErrorReporter()));
}
TEST(VerifyModel, TestSimpleModel) {
- FlatBufferBuilder builder;
- auto inputs = builder.CreateVector<int32_t>({0});
- auto outputs = builder.CreateVector<int32_t>({1});
- auto operator_codes = builder.CreateVector(std::vector<Offset<OperatorCode>>{
- CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "test")});
- auto operators =
- builder.CreateVector(std::vector<Offset<Operator>>{CreateOperator(
- builder, /*opcode_index=*/0,
- /*inputs=*/builder.CreateVector<int32_t>({0}),
- /*outputs=*/builder.CreateVector<int32_t>({1}), BuiltinOptions_NONE,
- /*builtin_options=*/0,
- /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS)});
- std::vector<int> shape;
- auto tensors = builder.CreateVector(std::vector<Offset<Tensor>>{
- CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/0,
- "input", /*quantization=*/0),
- CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/0,
- "output", /*quantization=*/0)});
- auto subgraph = std::vector<Offset<SubGraph>>(
- {CreateSubGraph(builder, tensors, inputs, outputs, operators,
- builder.CreateString("Main"))});
-
- auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, operator_codes,
- builder.CreateVector(subgraph),
- builder.CreateString("SmartReply"), /*buffers=*/0);
-
- ::tflite::FinishModelBuffer(builder, model);
- ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+ TfLiteFlatbufferModelBuilder builder;
+ builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "test");
+ builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4, 5, 6}, "input");
+ builder.AddTensor(
+ {2}, TensorType_STRING,
+ {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 19, 0, 0, 0, 'A', 'B', 'C'},
+ "data");
+ builder.AddTensor({2, 3}, TensorType_INT32, {}, "output");
+ builder.FinishModel({0, 1}, {2});
+ ASSERT_TRUE(builder.Verify());
}
TEST(VerifyModel, TestCorruptedData) {
string model = "123";
- ASSERT_FALSE(Verify(model.data(), model.size()));
+ ASSERT_FALSE(Verify(model.data(), model.size(), /*error_reporter=*/nullptr));
}
TEST(VerifyModel, TestUnsupportedVersion) {
@@ -106,7 +120,8 @@ TEST(VerifyModel, TestUnsupportedVersion) {
auto model = CreateModel(builder, /*version=*/1, /*operator_codes=*/0,
/*subgraphs=*/0, /*description=*/0, /*buffers=*/0);
::tflite::FinishModelBuffer(builder, model);
- ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+ ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
+ DefaultErrorReporter()));
}
TEST(VerifyModel, TestRandomModificationIsNotAllowed) {
@@ -116,20 +131,105 @@ TEST(VerifyModel, TestRandomModificationIsNotAllowed) {
/*subgraphs=*/0, /*description=*/0, /*buffers=*/0);
::tflite::FinishModelBuffer(builder, model);
- string model_content(reinterpret_cast<char *>(builder.GetBufferPointer()),
+ string model_content(reinterpret_cast<char*>(builder.GetBufferPointer()),
builder.GetSize());
for (int i = 0; i < model_content.size(); i++) {
model_content[i] = (model_content[i] + 137) % 255;
- EXPECT_FALSE(Verify(model_content.data(), model_content.size()))
+ EXPECT_FALSE(Verify(model_content.data(), model_content.size(),
+ DefaultErrorReporter()))
<< "Fail at position: " << i;
}
}
+TEST(VerifyModel, TestIntTensorShapeIsGreaterThanBuffer) {
+ TfLiteFlatbufferModelBuilder builder;
+ builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input");
+ builder.FinishModel({}, {});
+ ASSERT_FALSE(builder.Verify());
+}
+
+TEST(VerifyModel, TestIntTensorShapeIsSmallerThanBuffer) {
+ TfLiteFlatbufferModelBuilder builder;
+ builder.AddTensor({2, 1}, TensorType_UINT8, {1, 2, 3, 4}, "input");
+ builder.FinishModel({}, {});
+ ASSERT_FALSE(builder.Verify());
+}
+
+TEST(VerifyModel, TestIntTensorShapeOverflow) {
+ TfLiteFlatbufferModelBuilder builder;
+ builder.AddTensor({1024, 2048, 4096}, TensorType_UINT8, {1, 2, 3, 4},
+ "input");
+ builder.FinishModel({}, {});
+ ASSERT_FALSE(builder.Verify());
+}
+
+TEST(VerifyModel, TensorBufferIsNotValid) {
+ FlatBufferBuilder builder;
+ std::vector<int> shape = {2, 3};
+ auto tensors = builder.CreateVector(std::vector<Offset<Tensor>>{
+ CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/2,
+ "input", /*quantization=*/0)});
+ auto subgraph = std::vector<Offset<SubGraph>>(
+ {CreateSubGraph(builder, tensors, /*inputs=*/0, /*outputs=*/0,
+ /*operators=*/0, builder.CreateString("Main"))});
+
+ auto buffers = builder.CreateVector(std::vector<Offset<Buffer>>{
+ CreateBuffer(builder,
+ builder.CreateVector(std::vector<uint8>{1, 2, 3, 4, 5, 6})),
+ });
+
+ auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0,
+ builder.CreateVector(subgraph),
+ builder.CreateString("SmartReply"), buffers);
+
+ ::tflite::FinishModelBuffer(builder, model);
+ ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
+ DefaultErrorReporter()));
+}
+
+TEST(VerifyModel, StringTensorHasInvalidNumString) {
+ TfLiteFlatbufferModelBuilder builder;
+ builder.AddTensor(
+ {2}, TensorType_STRING,
+ {0x00, 0x00, 0x00, 0x20, 16, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B'},
+ "input");
+ builder.FinishModel({}, {});
+ ASSERT_FALSE(builder.Verify());
+}
+
+TEST(VerifyModel, StringTensorOffsetTooSmall) {
+ TfLiteFlatbufferModelBuilder builder;
+ builder.AddTensor(
+ {2}, TensorType_STRING,
+ {2, 0, 0, 0, 12, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B'}, "input");
+ builder.FinishModel({}, {});
+ ASSERT_FALSE(builder.Verify());
+}
+
+TEST(VerifyModel, StringTensorOffsetOutOfRange) {
+ TfLiteFlatbufferModelBuilder builder;
+ builder.AddTensor(
+ {2}, TensorType_STRING,
+ {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 22, 0, 0, 0, 'A', 'B'}, "input");
+ builder.FinishModel({}, {});
+ ASSERT_FALSE(builder.Verify());
+}
+
+TEST(VerifyModel, StringTensorIsLargerThanRequired) {
+ TfLiteFlatbufferModelBuilder builder;
+ builder.AddTensor(
+ {2}, TensorType_STRING,
+ {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B', 'C'},
+ "input");
+ builder.FinishModel({}, {});
+ ASSERT_FALSE(builder.Verify());
+}
+
// TODO(yichengfan): make up malicious files to test with.
} // namespace tflite
-int main(int argc, char **argv) {
+int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/py2tf/BUILD
index cea3738499..479ea9beca 100644
--- a/tensorflow/contrib/py2tf/BUILD
+++ b/tensorflow/contrib/py2tf/BUILD
@@ -23,6 +23,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/py2tf/impl",
+ "//tensorflow/contrib/py2tf/utils",
"@gast_archive//:gast",
"@six_archive//:six",
],
diff --git a/tensorflow/contrib/py2tf/__init__.py b/tensorflow/contrib/py2tf/__init__.py
index 878941b3a3..0d51bf0bf2 100644
--- a/tensorflow/contrib/py2tf/__init__.py
+++ b/tensorflow/contrib/py2tf/__init__.py
@@ -21,12 +21,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.py2tf import utils
from tensorflow.contrib.py2tf.impl.api import convert
from tensorflow.contrib.py2tf.impl.api import graph_ready
from tensorflow.contrib.py2tf.impl.api import to_code
from tensorflow.contrib.py2tf.impl.api import to_graph
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ['to_graph', 'to_code', 'convert', 'graph_ready']
+_allowed_symbols = ['to_graph', 'to_code', 'convert', 'graph_ready', 'utils']
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/py2tf/converters/BUILD
index b61fda3e91..3853c60f99 100644
--- a/tensorflow/contrib/py2tf/converters/BUILD
+++ b/tensorflow/contrib/py2tf/converters/BUILD
@@ -46,6 +46,7 @@ py_library(
deps = [
":converters",
"//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/contrib/py2tf/utils",
"@gast_archive//:gast",
],
)
diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/py2tf/converters/side_effect_guards.py
index ae96323966..1eda8ae630 100644
--- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py
+++ b/tensorflow/contrib/py2tf/converters/side_effect_guards.py
@@ -111,31 +111,20 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
# opt.minimize(loss)
# or:
# tf.py_func(...)
- args_scope = anno.getanno(node.value, 'args_scope')
- temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced)
- # TODO(mdan): Unsafe reference modification!
- args_scope.mark_write(temp_name)
template = """
- temp_result = call
- if temp_result is not None:
- if not isinstance(temp_result, (list, tuple)):
- temp_result = (temp_result,)
- ctx = tf.control_dependencies(temp_result)
- else:
- ctx = contextmanager(lambda: (yield))()
- with ctx:
- # TODO(mdan): Also insert ops to re-fetch if variables are involved.
+ with py2tf_utils.control_dependency_on_returns(tf, call):
+ # TODO(mdan): Also insert ops to re-fetch if variables are involved?
pass # Will be removed below.
"""
# TODO(mdan): This is brittle. Reorganize the mechanism.
- statements = templates.replace(
- template, call=node.value, temp_result=temp_name)
+ statements = templates.replace(template, call=node.value)
control_deps_guard = statements[-1]
control_deps_guard.body = []
# First, attempt to gate future evaluation of args. If that's not
# possible, gate all remaining statements (and that may fail too, see
# _visit_and_reindent.
+ args_scope = anno.getanno(node.value, 'args_scope')
guarded_args = tuple(args_scope.used & (args_scope.parent.modified
| args_scope.parent.returned))
if guarded_args:
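As an illustration of the guarded form this template expands to, the sketch below shows roughly what converted code looks like for a side-effecting call such as opt.minimize(loss); only control_dependency_on_returns and the py2tf_utils alias (added to COMPILED_IMPORT_STATEMENTS in impl/config.py) come from this change, every other name is a placeholder:

# Illustrative sketch only; train_step, opt, loss and counter are placeholders.
import tensorflow as tf
from tensorflow.contrib.py2tf import utils as py2tf_utils

def train_step(opt, loss, counter):
  # The side-effecting call is evaluated once, and the body is gated on its
  # return value via tf.control_dependencies (or a no-op context if the call
  # returns None).
  with py2tf_utils.control_dependency_on_returns(tf, opt.minimize(loss)):
    return tf.identity(counter)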
diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py
index 5c56973dc2..452d7ab2be 100644
--- a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py
+++ b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.py2tf import utils
from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.converters import side_effect_guards
from tensorflow.contrib.py2tf.pyct import compiler
@@ -46,6 +47,7 @@ class SideEffectGuardsTest(converter_test_base.TestCase):
node = side_effect_guards.transform(node, TestNamer())
result = compiler.ast_to_object(node)
setattr(result, 'state_ops', state_ops)
+ setattr(result, 'py2tf_utils', utils)
# TODO(mdan): Configure the namespaces instead of doing these hacks.
ops.identity = array_ops.identity
diff --git a/tensorflow/contrib/py2tf/impl/config.py b/tensorflow/contrib/py2tf/impl/config.py
index 0892241983..6525806a09 100644
--- a/tensorflow/contrib/py2tf/impl/config.py
+++ b/tensorflow/contrib/py2tf/impl/config.py
@@ -36,4 +36,5 @@ NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
# TODO(mdan): Make sure copybara renames the reference below.
COMPILED_IMPORT_STATEMENTS = (
'import tensorflow as tf',
-)
+ 'from tensorflow.contrib.py2tf import utils as '
+ 'py2tf_utils')
diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD
new file mode 100644
index 0000000000..01804aa883
--- /dev/null
+++ b/tensorflow/contrib/py2tf/utils/BUILD
@@ -0,0 +1,37 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "utils",
+ srcs = [
+ "__init__.py",
+ "context_managers.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ],
+)
+
+py_test(
+ name = "context_managers_test",
+ srcs = ["context_managers_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py
new file mode 100644
index 0000000000..bca33e89e9
--- /dev/null
+++ b/tensorflow/contrib/py2tf/utils/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility module that contains APIs usable in the generated code."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
diff --git a/tensorflow/contrib/py2tf/utils/context_managers.py b/tensorflow/contrib/py2tf/utils/context_managers.py
new file mode 100644
index 0000000000..47d9839997
--- /dev/null
+++ b/tensorflow/contrib/py2tf/utils/context_managers.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various context managers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+
+def control_dependency_on_returns(tf, return_value):
+  """Creates a TF control dependency on the return values of a function.
+
+  If the function has no return value, a no-op context is returned.
+
+ Args:
+ tf: The TensorFlow module.
+ return_value: The return value to set as control dependency.
+
+ Returns:
+ A context manager.
+ """
+ if return_value is None:
+ return contextlib.contextmanager(lambda: (yield))()
+ # TODO(mdan): Filter to tensor objects.
+ if not isinstance(return_value, (list, tuple)):
+ return_value = (return_value,)
+ return tf.control_dependencies(return_value)
diff --git a/tensorflow/contrib/py2tf/utils/context_managers_test.py b/tensorflow/contrib/py2tf/utils/context_managers_test.py
new file mode 100644
index 0000000000..c903f08252
--- /dev/null
+++ b/tensorflow/contrib/py2tf/utils/context_managers_test.py
@@ -0,0 +1,43 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for context_managers module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.utils import context_managers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+class ContextManagersTest(test.TestCase):
+
+ def test_control_dependency_on_returns(self):
+ # Just dry run them.
+ with context_managers.control_dependency_on_returns(ops, None):
+ pass
+ with context_managers.control_dependency_on_returns(
+ ops, constant_op.constant(1)):
+ pass
+ with context_managers.control_dependency_on_returns(
+ ops, [constant_op.constant(1),
+ constant_op.constant(2)]):
+ pass
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 3c5b34a0a6..b7d525a1fa 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -77,9 +77,13 @@ py_library(
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:layers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
],
)
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index aa605e6caa..8ec5334a39 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -17,7 +17,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
import re
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
@@ -26,14 +25,16 @@ from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
+from tensorflow.python.training import training_util
from tensorflow.python.util import compat
-def FoldBatchNorms(graph):
+def FoldBatchNorms(graph, freeze_batch_norm_delay=None, is_training=True):
"""Finds batch norm layers and folds them into preceding layers.
Folding only affects the following layers: Conv2D, fully connected, depthwise
@@ -41,15 +42,25 @@ def FoldBatchNorms(graph):
Args:
graph: Graph to walk and modify.
+ freeze_batch_norm_delay: How many steps to wait before freezing
+ moving mean and variance and using them for batch normalization. This value
+ is used only when is_training is True.
+ is_training: Bool, true if training
Raises:
ValueError: When batch norm folding fails.
"""
- _FoldFusedBatchNorms(graph)
- _FoldUnfusedBatchNorms(graph)
+ _FoldFusedBatchNorms(
+ graph,
+ freeze_batch_norm_delay=freeze_batch_norm_delay,
+ is_training=is_training)
+ _FoldUnfusedBatchNorms(
+ graph,
+ freeze_batch_norm_delay=freeze_batch_norm_delay,
+ is_training=is_training)
-def _FoldFusedBatchNorms(graph):
+def _FoldFusedBatchNorms(graph, freeze_batch_norm_delay, is_training):
"""Finds fused batch norm layers and folds them into preceding layers.
Folding only affects the following layers: Conv2D, fully connected, depthwise
@@ -57,6 +68,9 @@ def _FoldFusedBatchNorms(graph):
Args:
graph: Graph to walk and modify.
+ freeze_batch_norm_delay: How many steps to wait before freezing
+ moving mean and variance and using them for batch normalization
+ is_training: Bool, true if training
Raises:
ValueError: When batch norm folding fails.
@@ -67,8 +81,7 @@ def _FoldFusedBatchNorms(graph):
# `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
# named `scope`. Otherwise, TF creates a unique scope whose name starts with
# `scope`.
- with graph.as_default(), graph.name_scope(scope + sep), ops.device(
- match.bn_op.device):
+ with graph.as_default(), graph.name_scope(scope + sep):
with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
# new weights = old weights * gamma / sqrt(variance + epsilon)
# new biases = -mean * gamma / sqrt(variance + epsilon) + beta
@@ -79,9 +92,18 @@ def _FoldFusedBatchNorms(graph):
match.mean_tensor * multiplier_tensor,
name='bias')
+ correction_scale, correction_recip, correction_offset = None, None, None
+ if is_training:
+ correction_scale, correction_recip, correction_offset = (
+ _ComputeBatchNormCorrections(
+ context='',
+ match=match,
+ freeze_batch_norm_delay=freeze_batch_norm_delay,
+ fused_batch_norm=True))
# The shape of depthwise weights is different, so we need to reshape the
# multiplier_tensor to ensure that the scaled_weight_tensor has the
# expected shape.
+ weights = match.weight_tensor
if match.layer_op.type == 'DepthwiseConv2dNative':
new_shape = [
match.weight_tensor.get_shape().as_list()[2],
@@ -90,15 +112,29 @@ def _FoldFusedBatchNorms(graph):
multiplier_tensor = array_ops.reshape(
multiplier_tensor, new_shape, name='scale_reshape')
+ if correction_scale is not None:
+ correction_scale = array_ops.reshape(
+ correction_scale, new_shape, name='correction_reshape')
+
+ if correction_scale is not None:
+ weights = math_ops.multiply(
+ correction_scale, weights, name='correction_mult')
+
# TODO(suharshs): This naming of the following ops needs to carefully
# follow the naming expected by quantize.py. Generalize the quantize code
# to not require these delicate naming conventions.
scaled_weight_tensor = math_ops.multiply(
- match.weight_tensor, multiplier_tensor, name='mul_fold')
+ weights, multiplier_tensor, name='mul_fold')
new_layer_tensor = _CloneWithNewOperands(
match.layer_op, match.input_tensor, scaled_weight_tensor)
+ if correction_recip is not None:
+ new_layer_tensor = math_ops.multiply(
+ correction_recip, new_layer_tensor, name='post_conv_mul')
+ new_layer_tensor = math_ops.add(new_layer_tensor, (correction_offset),
+ 'correction_add')
+
bias_add_tensor = math_ops.add(
new_layer_tensor, bias_tensor, name='add_fold')
@@ -165,6 +201,8 @@ def _FindFusedBatchNorms(graph):
mean_pattern = graph_matcher.OpTypePattern('*')
variance_pattern = graph_matcher.OpTypePattern('*')
+ moving_average_pattern = graph_matcher.OpTypePattern('*')
+ bn_decay_pattern = graph_matcher.OpTypePattern('*')
conv_pattern = graph_matcher.OpTypePattern(
'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern])
# MatMul has a Reshape between it and FusedBatchNorm.
@@ -180,6 +218,11 @@ def _FindFusedBatchNorms(graph):
conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
variance_pattern
])
+ conv_moving_average_sub_pattern = graph_matcher.OpTypePattern(
+ 'Sub', inputs=[moving_average_pattern, conv_batch_norm_pattern])
+ # TODO(suharshs): Use a OneofPattern here when available
+ conv_moving_average_mul_pattern = graph_matcher.OpTypePattern(
+ 'Mul', inputs=[conv_moving_average_sub_pattern, bn_decay_pattern])
matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
@@ -191,8 +234,34 @@ def _FindFusedBatchNorms(graph):
inputs=[matmul_batch_norm_pattern,
graph_matcher.OpTypePattern('*')])
+ matmul_moving_average_sub_pattern = graph_matcher.OpTypePattern(
+ 'Sub', inputs=[moving_average_pattern, matmul_batch_norm_pattern])
+ matmul_moving_average_mul_pattern = graph_matcher.OpTypePattern(
+ 'Mul', inputs=[matmul_moving_average_sub_pattern, bn_decay_pattern])
+
conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern)
matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern)
+ conv_moving_average_mul_matcher = graph_matcher.GraphMatcher(
+ conv_moving_average_mul_pattern)
+ matmul_moving_average_mul_matcher = graph_matcher.GraphMatcher(
+ matmul_moving_average_mul_pattern)
+
+ def _GetMovingAverageTensors(graph, moving_avg_mul_matcher,
+ moving_avg_sub_pattern, bn_op):
+ """Gets the moving mean and variance tensors and the batch norm momentum."""
+ for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
+ sub_op = mul_match_result.get_op(moving_avg_sub_pattern)
+
+ if sub_op.inputs[1].name == bn_op.outputs[1].name:
+ # During training: Batch Mean is bn_op.outputs[1]
+ moving_mean_tensor = sub_op.inputs[0]
+ bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
+ if sub_op.inputs[1].name == bn_op.outputs[2].name:
+ # During training: Batch Var is bn_op.outputs[2]
+ moving_variance_tensor = sub_op.inputs[0]
+ bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
+ return (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
+ bn_decay_var_tensor)
def _GetCommonTensors(match_result, bn_op, bn_input_tensor):
"""Gets tensors needed for FusedBatchNormMatch from match_result."""
@@ -222,10 +291,14 @@ def _FindFusedBatchNorms(graph):
# calculation, the variance is corrected by the term N/N-1 (Bessel's
# correction). The variance tensor read from FuseBatchNorm has bessel's
# correction applied, so we undo it here.
- n = math_ops.cast(
- array_ops.size(bn_input_tensor) / array_ops.size(mean_tensor),
- dtypes.float32)
- variance_tensor = bn_op.outputs[2] * (n - 1) / n
+ scope, sep, _ = bn_op.name.rpartition('/')
+ g = ops.get_default_graph()
+ with g.as_default(), g.name_scope(scope + sep):
+ n = math_ops.cast(
+ array_ops.size(bn_input_tensor) / array_ops.size(mean_tensor),
+ dtypes.float32)
+ variance_tensor = math_ops.multiply(
+ bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction')
else:
mean_tensor = match_result.get_tensor(mean_pattern)
variance_tensor = match_result.get_tensor(variance_pattern)
@@ -233,15 +306,30 @@ def _FindFusedBatchNorms(graph):
variance_tensor)
for match_result in conv_matcher.match_graph(graph):
+ moving_mean_tensor = None
+ moving_variance_tensor = None
+ bn_decay_mean_tensor = None
+ bn_decay_var_tensor = None
layer_op = match_result.get_op(conv_pattern)
layer_tensor = match_result.get_tensor(conv_pattern)
bn_op = match_result.get_op(conv_batch_norm_pattern)
- # In the case of convolution the output_tensor is the output of bn_op.
- output_tensor = bn_op.outputs[0]
+ if bn_op.get_attr('is_training'):
+ (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
+ bn_decay_var_tensor) = _GetMovingAverageTensors(
+ graph,
+ moving_avg_mul_matcher=conv_moving_average_mul_matcher,
+ moving_avg_sub_pattern=conv_moving_average_sub_pattern,
+ bn_op=bn_op)
+ output_tensor = bn_op.outputs[0]
+ batch_epsilon_tensor = bn_op.get_attr('epsilon')
(input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
- variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor)
- yield _FusedBatchNormMatch(
+ variance_tensor) = _GetCommonTensors(
+ match_result,
+ bn_op,
+ layer_tensor,
+ )
+ yield _BatchNormMatch(
layer_op=layer_op,
bn_op=bn_op,
output_tensor=output_tensor,
@@ -250,20 +338,38 @@ def _FindFusedBatchNorms(graph):
gamma_tensor=gamma_tensor,
beta_tensor=beta_tensor,
mean_tensor=mean_tensor,
- variance_tensor=variance_tensor)
+ variance_tensor=variance_tensor,
+ moving_mean_tensor=moving_mean_tensor,
+ moving_variance_tensor=moving_variance_tensor,
+ bn_decay_mean_tensor=bn_decay_mean_tensor,
+ bn_decay_var_tensor=bn_decay_var_tensor,
+ batch_epsilon_tensor=batch_epsilon_tensor)
for match_result in matmul_matcher.match_graph(graph):
+ moving_mean_tensor = None
+ moving_variance_tensor = None
+ bn_decay_mean_tensor = None
+ bn_decay_var_tensor = None
layer_op = match_result.get_op(matmul_pattern)
layer_tensor = match_result.get_tensor(matmul_pattern)
bn_op = match_result.get_op(matmul_batch_norm_pattern)
+ if bn_op.get_attr('is_training'):
+ (moving_mean_tensor, bn_decay_mean_tensor, moving_variance_tensor,
+ bn_decay_var_tensor) = _GetMovingAverageTensors(
+ graph,
+ moving_avg_mul_matcher=matmul_moving_average_mul_matcher,
+ moving_avg_sub_pattern=matmul_moving_average_sub_pattern,
+ bn_op=bn_op)
+
# In the MatMul case, the output of batch norm is reshaped back into a
# 2D tensor, so the output_tensor is the output of the Reshape op.
output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
output_tensor = output_reshape_op.outputs[0]
+ batch_epsilon_tensor = bn_op.get_attr('epsilon')
(input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor)
- yield _FusedBatchNormMatch(
+ yield _BatchNormMatch(
layer_op=layer_op,
bn_op=bn_op,
output_tensor=output_tensor,
@@ -272,15 +378,21 @@ def _FindFusedBatchNorms(graph):
gamma_tensor=gamma_tensor,
beta_tensor=beta_tensor,
mean_tensor=mean_tensor,
- variance_tensor=variance_tensor)
+ variance_tensor=variance_tensor,
+ moving_mean_tensor=moving_mean_tensor,
+ moving_variance_tensor=moving_variance_tensor,
+ bn_decay_mean_tensor=bn_decay_mean_tensor,
+ bn_decay_var_tensor=bn_decay_var_tensor,
+ batch_epsilon_tensor=batch_epsilon_tensor)
-class _FusedBatchNormMatch(object):
- """Contains all information related to a found FusedBatchNorm."""
+class _BatchNormMatch(object):
+ """Contains all information related to a found Fused/UnfusedBatchNorm."""
def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
- variance_tensor):
+ variance_tensor, moving_mean_tensor, moving_variance_tensor,
+ bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor):
self._layer_op = layer_op
self._bn_op = bn_op
self._output_tensor = output_tensor
@@ -290,6 +402,11 @@ class _FusedBatchNormMatch(object):
self._beta_tensor = beta_tensor
self._mean_tensor = mean_tensor
self._variance_tensor = variance_tensor
+ self._moving_mean_tensor = moving_mean_tensor
+ self._moving_variance_tensor = moving_variance_tensor
+ self._bn_decay_mean_tensor = bn_decay_mean_tensor
+ self._bn_decay_var_tensor = bn_decay_var_tensor
+ self._batch_epsilon_tensor = batch_epsilon_tensor
@property
def layer_op(self):
@@ -327,8 +444,28 @@ class _FusedBatchNormMatch(object):
def variance_tensor(self):
return self._variance_tensor
+ @property
+ def moving_mean_tensor(self):
+ return self._moving_mean_tensor
+
+ @property
+ def moving_variance_tensor(self):
+ return self._moving_variance_tensor
+
+ @property
+ def batch_epsilon_tensor(self):
+ return self._batch_epsilon_tensor
+
+ @property
+ def bn_decay_mean_tensor(self):
+ return self._bn_decay_mean_tensor
+
+ @property
+ def bn_decay_var_tensor(self):
+ return self._bn_decay_var_tensor
+
-def _FoldUnfusedBatchNorms(graph):
+def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training):
"""Finds unfused batch norm layers and folds them into preceding layers.
Folding only affects the following layers: Conv2D, fully connected, depthwise
@@ -336,6 +473,9 @@ def _FoldUnfusedBatchNorms(graph):
Args:
graph: Graph to walk and modify.
+ freeze_batch_norm_delay: How many steps to wait before freezing
+ moving mean and variance and using them for batch normalization
+ is_training: Bool, True if training
Raises:
ValueError: When batch norm folding fails.
@@ -346,7 +486,12 @@ def _FoldUnfusedBatchNorms(graph):
has_scaling = _HasScaling(graph, input_to_ops_map, bn)
# The mangling code intimately depends on BatchNorm node's internals.
- original_op, folded_op = _CreateFoldedOp(graph, bn, has_scaling=has_scaling)
+ original_op, folded_op = _CreateFoldedOp(
+ graph,
+ bn,
+ has_scaling=has_scaling,
+ freeze_batch_norm_delay=freeze_batch_norm_delay,
+ is_training=is_training)
activation = common.GetEndpointActivationOp(graph, bn)
if activation:
@@ -407,7 +552,186 @@ def _HasScaling(graph, input_to_ops_map, bn):
return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
-def _CreateFoldedOp(graph, context, has_scaling):
+def _GetBatchNormParams(graph, context, has_scaling):
+ """Extracts relevant tensors for folding batch norms.
+
+ Args:
+ graph: Graph to inspect.
+ context: The scope under which we look for batch norm params
+ has_scaling: Bool that specifies if scaling is done as part of batch
+ norm
+
+ Returns:
+ _BatchNormMatch containing all required batch norm parameters
+ """
+ gamma_tensor = None
+ batch_mean_tensor = None
+ batch_variance_tensor = None
+ moving_mean_tensor = None
+ moving_variance_tensor = None
+ batch_epsilon_tensor = None
+ bn_decay_mean_tensor = None
+ bn_decay_var_tensor = None
+
+ split_context = context.split('/')
+ base_context = split_context[-1]
+
+ oplist = graph.get_operations()
+ op_suffix_gamma = base_context + '/BatchNorm/gamma'
+ op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze'
+ op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1'
+ op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read'
+ op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read'
+ op_suffix_epsilon = base_context + '/BatchNorm/batchnorm/add/y'
+ op_suffix_bn_decay_mean = base_context + '/BatchNorm/AssignMovingAvg/decay'
+ op_suffix_bn_decay_var = base_context + '/BatchNorm/AssignMovingAvg_1/decay'
+
+ # Parse through list of ops to find relevant ops
+ for op in oplist:
+ if op.name.endswith(op_suffix_mean):
+ # This is an efficient way to check for two things:
+ # Is batch norm present and is it training mode?
+ # Batch statistics are computed only during batch norm in training
+ batch_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
+ if op.name.endswith(op_suffix_variance):
+ batch_variance_tensor = graph.get_tensor_by_name(op.name + ':0')
+ if op.name.endswith(op_suffix_moving_mean):
+ moving_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
+ if op.name.endswith(op_suffix_moving_variance):
+ moving_variance_tensor = graph.get_tensor_by_name(op.name + ':0')
+ if op.name.endswith(op_suffix_epsilon):
+ batch_epsilon_tensor = graph.get_tensor_by_name(op.name + ':0')
+ if op.name.endswith(op_suffix_bn_decay_mean):
+ bn_decay_mean_tensor = graph.get_tensor_by_name(op.name + ':0')
+ if op.name.endswith(op_suffix_bn_decay_var):
+ bn_decay_var_tensor = graph.get_tensor_by_name(op.name + ':0')
+ if has_scaling:
+ if op.name.endswith(op_suffix_gamma):
+ gamma_tensor = graph.get_tensor_by_name(op.name + ':0')
+
+ if not has_scaling:
+ gamma_tensor = array_ops.ones(batch_mean_tensor.shape)
+
+ return _BatchNormMatch(
+ layer_op=None,
+ bn_op=None,
+ output_tensor=None,
+ input_tensor=None,
+ weight_tensor=None,
+ gamma_tensor=gamma_tensor,
+ beta_tensor=None,
+ mean_tensor=batch_mean_tensor,
+ variance_tensor=batch_variance_tensor,
+ moving_mean_tensor=moving_mean_tensor,
+ moving_variance_tensor=moving_variance_tensor,
+ bn_decay_mean_tensor=bn_decay_mean_tensor,
+ bn_decay_var_tensor=bn_decay_var_tensor,
+ batch_epsilon_tensor=batch_epsilon_tensor)
+
+
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
+ fused_batch_norm):
+ """Computes batch norm correction params.
+
+ Before batch normalization is frozen:
+ We use batch statistics for batch norm.
+ correction_scale = sigma_b/sigma_mv
+ correction_recip = 1/correction_scale
+ correction_offset = 0
+
+ After batch normalization is frozen:
+ correction_scale = sigma_b/sigma_mv
+ correction_recip = 1
+ correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv).
+
+  Batch norm is frozen if global_step > freeze_batch_norm_delay.
+  The corrections ensure that:
+  a) The weights are quantized after scaling by gamma/sigma_mv. This enables
+  smoother training as the scaling on the weights changes slowly, rather than
+  jumping across mini-batches.
+  b) Changing the values of the corrections allows one to switch between using
+  batch statistics and using the moving mean and variance, without requiring
+  changes to batch_norm.
+
+
+ Args:
+ context: The scope under which we look for batch norm params
+    match: Object containing required batch norm tensors for correction
+ computation
+ freeze_batch_norm_delay: Delay in steps at which computation switches
+ from regular batch norm to frozen mean and variance.
+ fused_batch_norm: Bool, true if fused batch norm is used
+
+ Returns:
+ A tuple of correction_scale, correction_recip, correction_offset
+ """
+
+ g = ops.get_default_graph()
+ with g.name_scope(context + 'batch_norm_correction'):
+ recip_sigma_mv = math_ops.rsqrt(
+ match.moving_variance_tensor + match.batch_epsilon_tensor)
+ recip_sigma = math_ops.rsqrt(
+ match.variance_tensor + match.batch_epsilon_tensor)
+ correction_scale = math_ops.divide(
+ recip_sigma_mv, recip_sigma, name='scale_compute')
+ correction_scale = array_ops.identity(
+ correction_scale, name='correction_scale')
+ correction_recip = math_ops.reciprocal(
+ correction_scale, name='reciprocal_compute')
+ correction_offset = math_ops.multiply(
+ match.gamma_tensor,
+ match.mean_tensor * recip_sigma -
+ match.moving_mean_tensor * recip_sigma_mv,
+ name='offset_compute')
+
+ if freeze_batch_norm_delay is not None:
+ use_mv_avg = math_ops.greater_equal(
+ training_util.get_or_create_global_step(),
+ freeze_batch_norm_delay,
+ name='use_moving_average')
+ else:
+ use_mv_avg = False
+
+ bn_decay_zero = 0.0
+ bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
+  bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
+
+ bn_decay_mean_out = utils.smart_cond(
+ use_mv_avg,
+ lambda: bn_decay_zero,
+ lambda: match.bn_decay_mean_tensor,
+ name='freeze_moving_mean')
+ graph_editor.reroute_ts(
+ [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+ can_modify=bn_decay_mean_consumers)
+
+ if fused_batch_norm is False:
+ bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
+ bn_decay_var_out = utils.smart_cond(
+ use_mv_avg,
+ lambda: bn_decay_zero,
+ lambda: match.bn_decay_var_tensor,
+ name='freeze_moving_var')
+ graph_editor.reroute_ts(
+ [bn_decay_var_out], [match.bn_decay_var_tensor],
+ can_modify=bn_decay_var_consumers)
+
+ correction_recip = utils.smart_cond(
+ use_mv_avg,
+ lambda: array_ops.ones(correction_scale.shape),
+ lambda: correction_recip,
+ name='correction_recip')
+
+ correction_offset = utils.smart_cond(
+ use_mv_avg,
+ lambda: correction_offset,
+ lambda: array_ops.zeros(correction_offset.shape),
+ name='correction_offset')
+ return correction_scale, correction_recip, correction_offset
+
+
+def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
+ is_training):
"""Folds in batch norm layer into preceding convolution or FC layer.
Creates 3 new nodes, connects their inputs and adds them to the graph:
@@ -419,6 +743,9 @@ def _CreateFoldedOp(graph, context, has_scaling):
context: String, batch norm context, i.e. node into which BatchNorm is
nested.
has_scaling: Whether the batch norm has scaling enabled.
+ freeze_batch_norm_delay: How many steps to wait before freezing
+ moving mean and variance and using them for batch normalization
+ is_training: Bool, true if training
Raises:
ValueError: When operation type is not supported, or input and output tensor
@@ -435,19 +762,43 @@ def _CreateFoldedOp(graph, context, has_scaling):
mul_scale_name)
op_below = mul_scale.inputs[0].op
weights = op_below.inputs[1]
-
+ match = _GetBatchNormParams(
+ graph=graph, context=context, has_scaling=has_scaling)
+ correction_scale, correction_recip, correction_offset = None, None, None
+ if is_training:
+ correction_scale, correction_recip, correction_offset = (
+ _ComputeBatchNormCorrections(
+ context=context,
+ match=match,
+ freeze_batch_norm_delay=freeze_batch_norm_delay,
+ fused_batch_norm=False))
# Special handling for weights of depthwise convolution.
if op_below.type == 'DepthwiseConv2dNative':
- new_shape = [weights.get_shape().as_list()[2],
- weights.get_shape().as_list()[3]]
+ new_shape = [
+ weights.get_shape().as_list()[2],
+ weights.get_shape().as_list()[3]
+ ]
scale_name = 'mul' if has_scaling else 'Rsqrt'
- scale = graph.get_operation_by_name(context + '/BatchNorm/batchnorm/' +
- scale_name)
+ scale = graph.get_operation_by_name(
+ context + '/BatchNorm/batchnorm/' + scale_name)
scale = array_ops.reshape(scale.outputs[0], new_shape,
context + '/scale_reshape')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold',
- [(0, weights), (1, scale)])
+
+ if correction_scale is not None:
+ correction_scale = array_ops.reshape(correction_scale, new_shape,
+ context + '/correction_reshape')
+ with ops.device(mul_scale.device):
+ weights = math_ops.multiply(correction_scale, weights,
+ context + '/correction_mult')
+
+ mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights),
+ (1, scale)])
elif op_below.type in ['Conv2D', 'MatMul']:
+
+ if correction_scale is not None:
+ with ops.device(mul_scale.device):
+ weights = math_ops.multiply(correction_scale, weights,
+ context + '/correction_mult')
mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)])
else:
raise ValueError('Cannot handle operation of type: %s' % op_below.op)
@@ -456,10 +807,17 @@ def _CreateFoldedOp(graph, context, has_scaling):
conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
[(1, mul_fold.outputs[0])])
- add_shift = graph.get_operation_by_name(context +
- '/BatchNorm/batchnorm/add_1')
- add_fold = _CloneOp(add_shift, context + '/add_fold',
- [(0, conv_or_fc_folded.outputs[0])])
+ add_shift = graph.get_operation_by_name(
+ context + '/BatchNorm/batchnorm/add_1')
+
+ corrected_output = conv_or_fc_folded.outputs[0]
+ if correction_offset is not None:
+ with ops.device(conv_or_fc_folded.device):
+ corrected_output = math_ops.multiply(correction_recip, corrected_output,
+ context + '/post_conv_mul')
+ corrected_output = math_ops.add(corrected_output, (correction_offset),
+ context + '/correction_add')
+ add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)])
_AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
return add_shift, add_fold
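The correction algebra documented in _ComputeBatchNormCorrections can be checked numerically. The NumPy sketch below is illustrative only (gamma, beta, epsilon and the statistics are made-up per-channel scalars standing in for the matched tensors); it confirms that the folded path reproduces standard batch norm both before freezing (batch statistics) and after freezing (moving averages):

# Illustrative check of the batch norm folding corrections; all values are
# made up and stand in for the per-channel tensors matched in the graph.
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(8)                          # stands in for conv(inputs, weights)
gamma, beta, eps = 1.5, 0.3, 1e-3
mu_b, var_b = x.mean(), x.var()           # batch statistics
mu_mv, var_mv = 0.1, 0.8                  # moving averages
sigma_b, sigma_mv = np.sqrt(var_b + eps), np.sqrt(var_mv + eps)

# Folded multiplier and bias, as in the mul_fold / add_fold nodes.
multiplier = gamma / sigma_b
bias = beta - mu_b * multiplier

# Corrections as documented: weights end up scaled by gamma / sigma_mv.
correction_scale = sigma_b / sigma_mv
correction_offset = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)

def folded(recip, offset):
  out = x * correction_scale * multiplier   # correction_mult, then mul_fold
  return out * recip + offset + bias        # post_conv_mul, correction_add, add_fold

before_freeze = folded(recip=1.0 / correction_scale, offset=0.0)
after_freeze = folded(recip=1.0, offset=correction_offset)

assert np.allclose(before_freeze, gamma * (x - mu_b) / sigma_b + beta)
assert np.allclose(after_freeze, gamma * (x - mu_mv) / sigma_mv + beta)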
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
index ecf321ff57..330bd8a647 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -46,26 +46,27 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
def _RunTestOverParameters(self, test_fn):
parameters_list = [
- # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm)
- (nn_ops.relu6, 'Relu6', False, False, False),
- (nn_ops.relu, 'Relu', False, False, False),
- (nn_ops.relu6, 'Relu6', True, False, False),
- (nn_ops.relu, 'Relu', True, False, False),
- (nn_ops.relu6, 'Relu6', False, True, False),
- (nn_ops.relu, 'Relu', False, True, False),
- (nn_ops.relu6, 'Relu6', True, True, False),
- (nn_ops.relu, 'Relu', True, True, False),
+ # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm,
+ # freeze_batch_norm_delay)
+ (nn_ops.relu6, 'Relu6', False, False, False, 100),
+ (nn_ops.relu, 'Relu', False, False, False, None),
+ (nn_ops.relu6, 'Relu6', True, False, False, 100),
+ (nn_ops.relu, 'Relu', True, False, False, None),
+ (nn_ops.relu6, 'Relu6', False, True, False, 100),
+ (nn_ops.relu, 'Relu', False, True, False, None),
+ (nn_ops.relu6, 'Relu6', True, True, False, 100),
+ (nn_ops.relu, 'Relu', True, True, False, None),
# Fused batch norm always has scaling enabled.
- (nn_ops.relu6, 'Relu6', False, True, True),
- (nn_ops.relu, 'Relu', False, True, True),
- (nn_ops.relu6, 'Relu6', True, True, True),
- (nn_ops.relu, 'Relu', True, True, True),
+ (nn_ops.relu6, 'Relu6', False, True, True, None),
+ (nn_ops.relu, 'Relu', False, True, True, 100),
+ (nn_ops.relu6, 'Relu6', True, True, True, None),
+ (nn_ops.relu, 'Relu', True, True, True, 100),
]
for params in parameters_list:
- test_fn(params[0], params[1], params[2], params[3], params[4])
+ test_fn(params[0], params[1], params[2], params[3], params[4], params[5])
def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
- fused_batch_norm):
+ fused_batch_norm, freeze_batch_norm_delay):
"""Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.
Args:
@@ -75,6 +76,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
+ freeze_batch_norm_delay: None or the number of steps after which training
+ switches to using frozen mean and variance
"""
g = ops.Graph()
with g.as_default():
@@ -99,12 +102,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
- fold_batch_norms.FoldBatchNorms(g)
+ fold_batch_norms.FoldBatchNorms(
+ g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul, [
- scope + '/weights/read',
+ scope + '/correction_mult',
self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold'])
@@ -113,12 +117,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
self.assertEqual(folded_conv.type, 'Conv2D')
self._AssertInputOpsAre(folded_conv,
[scope + '/mul_fold', inputs.op.name])
- self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
+ self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add, [
- scope + '/Conv2D_Fold',
+ scope + '/correction_add',
self._BathNormBiasName(scope, fused_batch_norm)
])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
@@ -128,7 +132,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
self._RunTestOverParameters(self._TestFoldConv2d)
def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass,
- has_scaling, fused_batch_norm):
+ has_scaling, fused_batch_norm,
+ freeze_batch_norm_delay):
"""Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.
Tests that folding works even with an input shape where some dimensions are
@@ -141,6 +146,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
+ freeze_batch_norm_delay: None or the number of steps after which training
+ switches to using frozen mean and variance
"""
g = ops.Graph()
with g.as_default():
@@ -164,12 +171,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
- fold_batch_norms.FoldBatchNorms(g)
+ fold_batch_norms.FoldBatchNorms(
+ g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul, [
- scope + '/weights/read',
+ scope + '/correction_mult',
self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold'])
@@ -177,12 +185,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold')
self.assertEqual(folded_conv.type, 'Conv2D')
self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name])
- self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
+ self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add, [
- scope + '/Conv2D_Fold',
+ scope + '/correction_add',
self._BathNormBiasName(scope, fused_batch_norm)
])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
@@ -192,7 +200,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
self._RunTestOverParameters(self._TestFoldConv2dUnknownShape)
def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass,
- has_scaling, fused_batch_norm):
+ has_scaling, fused_batch_norm,
+ freeze_batch_norm_delay):
"""Tests folding cases: inputs -> FC with batch norm -> Relu*.
Args:
@@ -202,6 +211,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
+ freeze_batch_norm_delay: None or the number of steps after which training
+ switches to using frozen mean and variance
"""
g = ops.Graph()
with g.as_default():
@@ -223,12 +234,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
- fold_batch_norms.FoldBatchNorms(g)
+ fold_batch_norms.FoldBatchNorms(
+ g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul, [
- scope + '/weights/read',
+ scope + '/correction_mult',
self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold'])
@@ -237,12 +249,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
self.assertEqual(folded_conv.type, 'MatMul')
self._AssertInputOpsAre(folded_conv,
[scope + '/mul_fold', inputs.op.name])
- self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
+ self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add, [
- scope + '/MatMul_Fold',
+ scope + '/correction_add',
self._BathNormBiasName(scope, fused_batch_norm)
])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
@@ -252,7 +264,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
self._RunTestOverParameters(self._TestFoldFullyConnectedLayer)
def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass,
- has_scaling, fused_batch_norm):
+ has_scaling, fused_batch_norm,
+ freeze_batch_norm_delay):
"""Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.
Args:
@@ -262,6 +275,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
+ freeze_batch_norm_delay: None or the number of steps after which training
+ switches to using frozen mean and variance
"""
g = ops.Graph()
with g.as_default():
@@ -286,7 +301,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
- fold_batch_norms.FoldBatchNorms(g)
+ fold_batch_norms.FoldBatchNorms(
+ g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
@@ -295,8 +311,7 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
else:
scale_reshape_op_name = scope + '/scale_reshape'
self._AssertInputOpsAre(folded_mul,
- [scope + '/depthwise_weights/read',
- scale_reshape_op_name])
+ [scope + '/correction_mult', scale_reshape_op_name])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold'])
scale_reshape = g.get_operation_by_name(scale_reshape_op_name)
@@ -311,12 +326,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
self._AssertInputOpsAre(folded_conv,
[scope + '/mul_fold', inputs.op.name])
- self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
+ self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add, [
- scope + '/depthwise_Fold',
+ scope + '/correction_add',
self._BathNormBiasName(scope, fused_batch_norm)
])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
@@ -326,7 +341,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
self._RunTestOverParameters(self._TestFoldDepthwiseConv2d)
def _TestCompareFoldAndUnfolded(self, relu, relu_op_name, with_bypass,
- has_scaling, fused_batch_norm):
+ has_scaling, fused_batch_norm,
+ freeze_batch_norm_delay):
"""Tests that running folded and unfolded BN returns the same results.
Args:
@@ -336,6 +352,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
+ freeze_batch_norm_delay: None or the number of steps after which training
+ switches to using frozen mean and variance
"""
random_seed.set_random_seed(1234)
unfolded_g = ops.Graph()
@@ -361,11 +379,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu_node = relu(node, name='test/' + relu_op_name)
-
folded_g = copy_graph.CopyGraph(unfolded_g)
with folded_g.as_default():
- fold_batch_norms.FoldBatchNorms(folded_g)
-
+ fold_batch_norms.FoldBatchNorms(
+ folded_g,
+ is_training=True,
+ freeze_batch_norm_delay=freeze_batch_norm_delay)
with session.Session(graph=unfolded_g) as sess:
sess.run(variables.global_variables_initializer())
grad_node = gradients.gradients(relu_node, inputs)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index bbd9743d80..89b744c559 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -52,9 +52,19 @@ def _create_graph(input_graph,
"""
# TODO(suharshs): Describe the process in more detail in the doc string.
g = copy_graph.CopyGraph(input_graph)
+ if is_training:
+ # TODO(raghuramank): Need to make freeze_batch_norm_delay
+ # a function of the batch size. For now setting this to 250 epochs.
+ # This corresponds to 5 million steps at a batch size of 64.
+ freeze_batch_norm_delay = 5000000
+ else:
+ freeze_batch_norm_delay = None
with g.as_default():
with ops.device(device_name_or_function):
- fold_batch_norms.FoldBatchNorms(g)
+ fold_batch_norms.FoldBatchNorms(
+ g,
+ freeze_batch_norm_delay=freeze_batch_norm_delay,
+ is_training=is_training)
quantize.Quantize(g, is_training=is_training)
if elements is None:
return g
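As a rough check on the 5,000,000-step default chosen above: a minimal arithmetic sketch, assuming an ImageNet-scale training set of roughly 1.28 million examples (an assumption; the TODO does not state which dataset the 250-epoch figure refers to).

```python
# Back-of-the-envelope arithmetic for freeze_batch_norm_delay = 5000000.
examples_per_epoch = 1280000        # assumption: ImageNet-scale training set
batch_size = 64
steps_per_epoch = examples_per_epoch // batch_size   # 20000
print(steps_per_epoch * 250)                          # 5000000
```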
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index cccaa2b833..6db373d2d5 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -26,6 +26,7 @@ py_library(
"python/training/resample.py",
"python/training/sampling_ops.py",
"python/training/sequence_queueing_state_saver.py",
+ "python/training/tensor_queue_dataset.py",
"python/training/training.py",
"python/training/tuner.py",
],
@@ -285,6 +286,28 @@ py_test(
],
)
+py_test(
+ name = "tensor_queue_dataset_test",
+ size = "large",
+ srcs = ["python/training/tensor_queue_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":training_py",
+ "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data",
+ "//third_party/py/numpy",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
new file mode 100644
index 0000000000..409aba817c
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -0,0 +1,200 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for Datasets and Iterators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util import nest as tf_nest
+
+
+class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
+ """A `Dataset` that prepends a queue to another `Dataset`.
+
+ A vector of handles to the queue is returned as the first component of
+ the associated iterator. This vector can be passed to
+ `enqueue_in_queue_dataset` to add new elements to the queue.
+ """
+
+ def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
+ """Initialize `PrependFromQueueAndPaddedBatchDataset`."""
+ super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
+ if sparse.any_sparse(input_dataset.output_classes):
+ raise TypeError(
+ "Batching of padded sparse tensors is not currently supported")
+ self._input_dataset = input_dataset
+ self._batch_size = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+ # pylint: disable=protected-access
+ if padded_shapes is None:
+ self._padded_shapes = nest.map_structure(
+ dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)
+ else:
+ self._padded_shapes = nest.map_structure_up_to(
+ input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
+ padded_shapes)
+ padding_values = (
+ padding_values if padding_values is not None else
+ dataset_ops._default_padding(input_dataset))
+ self._padding_values = nest.map_structure_up_to(
+ input_dataset.output_shapes, dataset_ops._padding_value_to_tensor,
+ padding_values, input_dataset.output_types)
+ # pylint: enable=protected-access
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset(
+ self._input_dataset._as_variant_tensor(),
+ batch_size=self._batch_size,
+ padded_shapes=[
+ ops.convert_to_tensor(s, dtype=dtypes.int64)
+ for s in nest.flatten(self._padded_shapes)
+ ],
+ padding_values=nest.flatten(self._padding_values),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ # pylint: enable=protected-access
+
+ @property
+ def output_classes(self):
+ return (ops.Tensor, self._input_dataset.output_classes)
+
+ def _as_batch_shape(self, shape_like):
+ return tensor_shape.vector(None).concatenate(
+ tensor_util.constant_value_as_shape(shape_like))
+
+ @property
+ def output_shapes(self):
+ # First output is a variant representing the Queue
+ return (tensor_shape.vector(None),
+ nest.map_structure(self._as_batch_shape, self._padded_shapes))
+
+ @property
+ def output_types(self):
+ # First output is a variant representing the Queue
+ return (dtypes.variant, self._input_dataset.output_types)
+
+
+def prepend_from_queue_and_padded_batch_dataset(batch_size,
+ padding_values=None,
+ padded_shapes=None):
+ """A transformation that prepends a queue to a `Dataset` and batches results.
+
+ A vector of handles to the queue is returned as the first component of the
+ associated iterator. This vector can be passed to `enqueue_in_queue_dataset`
+ to add new elements to the queue.
+
+ Below is an example of how this dataset might be used to split incoming
+ variable-length sequences into "head" and "rest" parts, where "rest" parts
+ are re-enqueued back into the dataset. A more realistic example would
+ perform some calculation on the "head" and modify some components of "rest"
+ with the result (before re-enqueueing).
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)])
+ # Make a dataset of variable-length vectors and their lengths.
+ dataset = dataset.map(lambda count: (count, tf.ones((count,))))
+ # Emit a queue we can prepend to, and counts/values as padded batch.
+ dataset = dataset.apply(
+ tf.contrib.training.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=10))
+ dataset = dataset.prefetch(1)
+
+ iterator = dataset.make_one_shot_iterator()
+ queue, (count, padded_value) = iterator.get_next()
+
+ # Split the padded_value into two pieces: head and rest
+ rest_indices = tf.squeeze(tf.where(count > 3), axis=1)
+ bound = tf.minimum(3, tf.reduce_max(count))
+ value_head = padded_value[:, :bound]
+ count_rest = tf.gather(count - 3, rest_indices)
+ value_rest = tf.gather(padded_value[:, bound:], rest_indices)
+ queue_rest = tf.gather(queue, rest_indices)
+ enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset(
+ queue_rest, (count_rest, value_rest))
+ with tf.control_dependencies([enqueue_rest_op]):
+ calculation = fn(value_head)
+
+ while True: # Will raise OutOfRange when finished with all pieces.
+ session.run(calculation)
+ ```
+
+ Args:
+ batch_size: `int64` scalar tensor. The batch size to use when performing
+ padded batching.
+ padding_values: (optional) Nested tuple of scalar tensors. If provided,
+ the structure and dtypes of padding_values should match that of
+ incoming dataset's `output_types`.
+ padded_shapes: (optional) Nested tuple of `int64` vector tensors.
+ If provided, the structure must match that of the incoming dataset's
+ `output_types`. If not provided, the incoming dataset's `output_shapes`
+ is used. Any unknown (`None` or `-1`) dimensions in the shapes are
+ treated as being unique per-batch: for each batch time, an unknown
+ dimension is replaced with the maximum given value of this dimension
+ across all tensors for the given component in the batch.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ return _PrependFromQueueAndPaddedBatchDataset(
+ dataset,
+ batch_size=batch_size,
+ padding_values=padding_values,
+ padded_shapes=padded_shapes)
+
+ return _apply_fn
+
+
+def enqueue_in_queue_dataset(queue, components):
+ """Enqueue components into queue from `PrependFromQueueAndPaddedBatchDataset`.
+
+ The components' dtypes and shapes must be compatible with the `output_shapes`
+ attribute of the `dataset` created by
+ `prepend_from_queue_and_padded_batch_dataset`. This operation supports both
+ non-batched and batched modes.
+
+ For more details, see the example in the docstring for
+ `prepend_from_queue_and_padded_batch_dataset`.
+
+ Args:
+ queue: `variant` scalar or vector tensor.
+ The tensor emitted by the first component of the iterator associated with
+ `prepend_from_queue_and_padded_batch_dataset`. If this is a scalar,
+ then the `components` input tensors should not have a prepended batch
+ dimension.
+ components: Nested tuple of tensors, each with a leading batch dimension
+ if `queue` is a vector. The structure, dtypes, and shapes
+ (excluding batch dimension) must match the nested tuples
+ `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue
+ output types and shapes) of the `dataset` emitted by
+ the original `prepend_from_queue_and_padded_batch_dataset` call.
+
+ Returns:
+ An `Operation` that enqueues `components` into the dataset(s) associated
+ with entries of `queue`.
+ """
+ return gen_dataset_ops.enqueue_in_queue_dataset(
+ queue=queue, components=tf_nest.flatten(components))
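A minimal usage sketch for `enqueue_in_queue_dataset` with `batch_size=1`, using the `tf.contrib.training` names from the docstring example above (a sketch only; the feed-back ordering noted in the comments mirrors the behavior exercised in the tests that follow):

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2])
dataset = dataset.apply(
    tf.contrib.training.prepend_from_queue_and_padded_batch_dataset(
        batch_size=1))
queue_handle, value = dataset.make_one_shot_iterator().get_next()
# Feed the negated value back into the queue; queued elements are emitted
# before the next element of the underlying dataset.
enqueue_negative = tf.contrib.training.enqueue_in_queue_dataset(
    queue_handle, -value)

with tf.Session() as sess:
  print(sess.run(value))                            # [0]
  value_1, _ = sess.run([value, enqueue_negative])  # emits [1], enqueues [-1]
  print(sess.run(value))                            # [-1]
```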
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
new file mode 100644
index 0000000000..0338f409a2
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -0,0 +1,355 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TensorQueueDataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
+
+ def testNoEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ self.assertEqual((dtypes.variant, dtypes.int32), dataset.output_types)
+ self.assertAllEqual(([None],) * 2,
+ [x.as_list() for x in dataset.output_shapes])
+ iterator = dataset.make_one_shot_iterator()
+ _, value = iterator.get_next()
+ self.assertEqual([0], self.evaluate(value))
+ self.assertEqual([1], self.evaluate(value))
+ self.assertEqual([2], self.evaluate(value))
+ with self.assertRaisesOpError("End of sequence"):
+ self.evaluate(value)
+
+ def testBatchedNoEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2))
+ iterator = dataset.make_one_shot_iterator()
+ _, value = iterator.get_next()
+ self.assertAllEqual([0, 1], self.evaluate(value))
+ self.assertAllEqual([2], self.evaluate(value))
+ with self.assertRaisesOpError("End of sequence"):
+ self.evaluate(value)
+
+ def testBatchedWithBiggerPaddingNoEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=2, padded_shapes=[3]))
+ iterator = dataset.make_one_shot_iterator()
+ _, value = iterator.get_next()
+ self.assertAllEqual([[0, 0, 0], [1, 0, 0]], self.evaluate(value))
+ self.assertAllEqual([[2, 0, 0]], self.evaluate(value))
+ with self.assertRaisesOpError("End of sequence"):
+ self.evaluate(value)
+
+ def testBatchedWithBiggerPaddingOneEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=1, padded_shapes=[3]))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
+ with self.test_session() as sess:
+ self.assertAllEqual([[0, 0, 0]], sess.run(value))
+ value_1, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([[1, 0, 0]], value_1)
+ value_2, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([[-1, 0, 0]], value_2)
+ value_3 = sess.run(value)
+ self.assertAllEqual([[1, 0, 0]], value_3)
+ value_4, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([[2, 0, 0]], value_4)
+ value_5 = sess.run(value)
+ self.assertAllEqual([[-2, 0, 0]], value_5)
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(value)
+
+ def testOneEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
+ with self.test_session() as sess:
+ self.assertEqual([0], sess.run(value))
+ value_1, _ = sess.run([value, enqueue_negative])
+ self.assertEqual([1], value_1)
+ value_2, _ = sess.run([value, enqueue_negative])
+ self.assertEqual([-1], value_2)
+ value_3 = sess.run(value)
+ self.assertEqual([1], value_3)
+ value_4, _ = sess.run([value, enqueue_negative])
+ self.assertEqual([2], value_4)
+ value_5 = sess.run(value)
+ self.assertEqual([-2], value_5)
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(value)
+
+ def testBatchedOneEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
+ enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]],
+ array_ops.expand_dims(
+ value[0], axis=0))
+ with self.test_session() as sess:
+ value_0, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([0, 1], value_0)
+ value_1, _ = sess.run([value, enqueue_zeroth])
+ self.assertAllEqual([0, -1], value_1)
+ value_2, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([0, 2], value_2)
+ self.assertAllEqual([0, -2], sess.run(value))
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(value)
+
+ def testManyEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue_many_more = [
+ tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i)
+ for i in range(1000)
+ ]
+ with self.test_session() as sess:
+ value_0, _ = sess.run((value, enqueue_many_more))
+ self.assertEqual([0], value_0)
+ rest = []
+ for _ in range(1000):
+ rest.append(sess.run(value))
+ self.assertEquals([[100 + i] for i in range(1000)], sorted(rest))
+ # Going back to the original input.
+ value_1, _ = sess.run((value, enqueue_many_more))
+ self.assertEqual(1, value_1)
+ rest = []
+ for _ in range(1000):
+ rest.append(sess.run(value))
+ self.assertEquals([[100 + i + 1] for i in range(1000)], sorted(rest))
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(value)
+
+ def testEnqueueWithPrefetch(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ # Prefetching will request additional values before they are
+ # available to the queue.
+ dataset = dataset.prefetch(buffer_size=3)
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1)
+ with self.test_session() as sess:
+ i = 0
+ while i < 4:
+ received, _ = sess.run((value, enqueue))
+ if received.size > 0:
+ self.assertAllEqual([i], received)
+ i += 1
+ received_last = False
+ while True:
+ try:
+ received = sess.run(value)
+ if received.size > 0:
+ self.assertAllEqual([4], received)
+ received_last = True
+ except errors.OutOfRangeError:
+ break
+ self.assertTrue(received_last)
+
+ def testDatasetWithPaddedShapeSmallerThanInputFails(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([[0, 0, 0]]).repeat(None)
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=1, padded_shapes=[2]))
+ iterator = dataset.make_one_shot_iterator()
+ _, value = iterator.get_next()
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(
+ r"Incompatible input shapes at component 0 between "
+ r"input dataset this dataset: \[3\] vs. \[2\]"):
+ sess.run(value)
+
+ def testEnqueueWithIncompatibleInputsFailsWithInformativeError(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0]).repeat(None)
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+
+ enqueue_bad_structure = tqd.enqueue_in_queue_dataset(
+ queue_handle, (value, value))
+ enqueue_bad_dtype = tqd.enqueue_in_queue_dataset(queue_handle,
+ np.array(
+ [1.0],
+ dtype=np.float32))
+ enqueue_bad_shape_no_batch_dim = tqd.enqueue_in_queue_dataset(
+ queue_handle, ([1],))
+ enqueue_bad_shape = tqd.enqueue_in_queue_dataset(queue_handle,
+ np.array(
+ [[1]], dtype=np.int32))
+
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(
+ "mismatched number of tensors. Queue expects 1 tensors but "
+ "tried to insert 2"):
+ sess.run(enqueue_bad_structure)
+ with self.assertRaisesOpError(r"Expected component 0 to have batched "
+ r"shape \[1,...\], but saw shape: \[\]"):
+ sess.run(enqueue_bad_shape_no_batch_dim)
+ with self.assertRaisesOpError(
+ r"mismatched shapes at component 0. Attempted to insert tensor "
+ r"with shape \[1\] but queue expected shape: \[\]"):
+ sess.run(enqueue_bad_shape)
+ with self.assertRaisesOpError(
+ r"mismatched dtypes at component 0. Attempted to insert tensor "
+ r"of type float but queue expected type: int32"):
+ sess.run(enqueue_bad_dtype)
+
+ def testEnqueueWithPaddedBatchFailsWithInformativeError(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ with self.assertRaisesRegexp(
+ TypeError, r"Unable to create padding for field of type 'variant'"):
+ dataset.padded_batch(batch_size=10, padded_shapes=[1])
+
+ def testOneEnqueueWithPadding(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6])
+ # Make a dataset of variable-length vectors and their lengths.
+ dataset = dataset.map(
+ lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype)))
+ # Emit a queue we can prepend to, and counts/values as padded
+ # batch.
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3))
+
+ iterator = dataset.make_one_shot_iterator()
+ queue, (count, padded_value) = iterator.get_next()
+
+ # Split the padded_value into two pieces: head and rest
+ rest_indices = array_ops.squeeze(array_ops.where(count > 2), axis=1)
+ bound = math_ops.minimum(2, math_ops.reduce_max(count))
+ value_head = padded_value[:, :bound]
+ count_rest = array_ops.gather(count - 2, rest_indices)
+ value_rest = array_ops.gather(padded_value, rest_indices)[:, bound:]
+ queue_rest = array_ops.gather(queue, rest_indices)
+ enqueue_rest_op = tqd.enqueue_in_queue_dataset(queue_rest,
+ (count_rest, value_rest))
+ with ops.control_dependencies([enqueue_rest_op]):
+ calc = array_ops.identity(value_head)
+
+ with self.test_session() as sess:
+ self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc))
+ self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc))
+ self.assertAllEqual([[6, 6]], sess.run(calc))
+ self.assertAllEqual([[6, 6]], sess.run(calc))
+ # Get some final batches due to prefetching.
+ for _ in range(3):
+ try:
+ self.assertAllEqual(
+ np.empty(shape=(0, 0), dtype=np.int32), sess.run(calc))
+ except errors.OutOfRangeError as e:
+ self.assertTrue(str(e).startswith("End of sequence"))
+
+ def testNonstandardPadding(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6])
+ # Make a dataset of variable-length vectors and their lengths.
+ dataset = dataset.map(
+ lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype)))
+ # Emit a queue we can prepend to, and counts/values as padded
+ # batch.
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=3, padding_values=(
+ 0,
+ -1,
+ )))
+
+ iterator = dataset.make_one_shot_iterator()
+ _, (unused_count, padded_value) = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]],
+ sess.run(padded_value))
+ self.assertAllEqual([[6] * 6], sess.run(padded_value))
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(padded_value)
+
+
+# TODO(ebrevdo): Figure out how to use run_core_tests to test state
+# saving of an iterator that's had some tensors enqueued into its queue.
+class PrependFromQueueAndPaddedBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testPrependFromQueueAndPaddedBatch(self):
+
+ def build_dataset(seq_lens):
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ lambda x: array_ops.fill([x], x)).apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=4))
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+ def testPrependFromQueueAndPaddedBatchNonDefaultPadding(self):
+
+ def build_dataset(seq_lens):
+
+ def fill_tuple(x):
+ filled = array_ops.fill([x], x)
+ return (filled, string_ops.as_string(filled))
+
+ padded_shape = [-1]
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ fill_tuple).apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=4,
+ padded_shapes=(padded_shape, padded_shape),
+ padding_values=(-1, "<end>")))
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/core/api_def/base_api/api_def_EnqueueInQueueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnqueueInQueueDataset.pbtxt
new file mode 100644
index 0000000000..9722f5ede3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_EnqueueInQueueDataset.pbtxt
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "EnqueueInQueueDataset"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_PrependFromQueueAndPaddedBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrependFromQueueAndPaddedBatchDataset.pbtxt
new file mode 100644
index 0000000000..d4549340fa
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PrependFromQueueAndPaddedBatchDataset.pbtxt
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "PrependFromQueueAndPaddedBatchDataset"
+}
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 3b309e915c..33a5d60eb7 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -340,8 +340,11 @@ Status GraphExecutionState::OptimizeGraph(
std::unordered_map<string, DeviceProperties> device_map;
Device* cpu_device = nullptr;
for (const auto& device : device_set_->devices()) {
- device_map[device->name()] =
- grappler::GetDeviceInfo(device->parsed_name());
+ DeviceProperties props = grappler::GetDeviceInfo(device->parsed_name());
+ if (props.type() == "UNKNOWN") {
+ continue;
+ }
+ device_map[device->name()] = props;
if (device->parsed_name().id == 0 &&
StringPiece(device->parsed_name().type) == "CPU" &&
device->GetAllocator(AllocatorAttributes()) != nullptr) {
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc
index fe2d1a1293..34a4013547 100644
--- a/tensorflow/core/distributed_runtime/tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/tensor_coding.cc
@@ -81,7 +81,7 @@ void TensorResponse::InitPartial(const RecvTensorResponse& response) {
Status TensorResponse::ParseFrom(Source* source) {
if (!on_host_) {
protobuf::io::CodedInputStream input(source->contents());
- input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited
+ input.SetTotalBytesLimit(INT_MAX); // Unlimited
// Pre-parse into local storage, then delegate to device.
if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) {
@@ -217,7 +217,7 @@ bool TensorResponse::ParseTensorSubmessage(
bool TensorResponse::ParseFast(Source* source) {
protobuf::io::CodedInputStream input(source->contents());
- input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited
+ input.SetTotalBytesLimit(INT_MAX); // Unlimited
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index d7d07ee7a5..020492a3e9 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -323,8 +323,13 @@ Status VirtualScheduler::Init() {
}
// Get the nodes that would run to output fetch_nodes.
+ bool ill_formed = false;
std::vector<const NodeDef*> nodes =
- ComputeTransitiveFanin(graph, fetch_nodes);
+ ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed);
+ if (ill_formed) {
+ return errors::InvalidArgument(
+ "Ill formed graph or invalid set of fetch nodes specified");
+ }
// TODO(dyoon): this is a bit inefficient as name_to_node is already built in
// ComputeTransitiveFanin().
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index f4e2de75a6..173ce9c09c 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -46,6 +46,7 @@ class GraphView {
};
explicit GraphView(GraphDef* graph);
+ GraphDef* GetGraph() const { return graph_; }
NodeDef* GetNode(const string& node_name) const;
// Get the specified input port. Note that the special '-1' port_id can be
// used to access the controlling nodes (i.e. the nodes connected to node_name
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 68de03e81c..8b9885e4c1 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -289,6 +289,7 @@ cc_library(
"//tensorflow/core/grappler/costs:graph_memory",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
+ "//tensorflow/core/grappler/utils:traversal",
],
)
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 6f95a00fa3..ffa03db262 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -497,7 +498,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
if (!IsAddN(node)) {
continue;
}
- // There is nothing to gain by optimizing nodes with 2 inputs of fewer.
+ // There is nothing to gain by optimizing nodes with 2 or fewer inputs.
if (view.NumFanins(node, false) <= 2) {
continue;
}
@@ -559,6 +560,54 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
VLOG(1) << "Missing properties for " << node->name();
continue;
}
+
+ // Compute a topological ordering for the node fanin.
+ std::unordered_map<NodeDef*, int> topo_order;
+ ReverseDfs(view, {node}, nullptr,
+ [&topo_order](NodeDef* n) {
+ int topo_index = topo_order.size();
+ topo_order[n] = topo_index;
+ },
+ nullptr);
+
+ std::vector<int> input_topo_index;
+
+ for (int i = 0; i < node->input_size(); ++i) {
+ const string& input = node->input(i);
+ const string node_name = NodeName(input);
+ NodeDef* node = view.GetNode(node_name);
+ input_topo_index.push_back(topo_order.at(node));
+ }
+ int min_input_topo_index = INT_MAX;
+ int min_input_id = -1;
+ for (int i = 0; i < node->input_size(); ++i) {
+ if (IsControlInput(node->input(i))) {
+ // control inputs are always last.
+ break;
+ }
+ const int current = input_topo_index[i];
+ if (current < min_input_topo_index) {
+ min_input_topo_index = current;
+ min_input_id = i;
+ }
+ }
+ CHECK_LE(0, min_input_id);
+ std::vector<string> pre_ctrl_deps;
+ std::vector<string> post_ctrl_deps;
+ for (int i = node->input_size() - 1; i >= 0; --i) {
+ if (!IsControlInput(node->input(i))) {
+ // control inputs are always last.
+ break;
+ }
+ if (input_topo_index[i] < min_input_topo_index) {
+ // These control dependencies can be executed before the node.
+ pre_ctrl_deps.push_back(node->input(i));
+ } else {
+ // These control dependencies should be executed after the node.
+ post_ctrl_deps.push_back(node->input(i));
+ }
+ }
+
const TensorShapeProto& shape =
properties.GetOutputProperties(node->name())[0].shape();
DataType dtype = node->attr().at("T").type();
@@ -573,13 +622,19 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
*(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape;
(*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name());
+ for (const string& ctrl_dep : pre_ctrl_deps) {
+ *tmp_var->add_input() = ctrl_dep;
+ }
+ *tmp_var->add_input() =
+ AsControlDependency(NodeName(node->input(min_input_id)));
+
// Initialize it to zero
NodeDef* zeros = item->graph.add_node();
zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros"));
zeros->set_op("ZerosLike");
zeros->set_device(device);
(*zeros->mutable_attr())["T"].set_type(dtype);
- *zeros->add_input() = node->input(0);
+ *zeros->add_input() = node->input(min_input_id);
NodeDef* initialize = item->graph.add_node();
initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer"));
@@ -593,9 +648,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
std::vector<NodeDef*> accumulates;
for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i);
- if (IsControlInput(input)) {
- *zeros->add_input() = input;
- } else {
+ if (!IsControlInput(input)) {
NodeDef* accumulate = item->graph.add_node();
accumulate->set_name(
strings::StrCat(node->name(), "/tmp_var_accum_", i));
@@ -618,6 +671,10 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
for (const NodeDef* accum : accumulates) {
*node->add_input() = AsControlDependency(accum->name());
}
+ for (const string& ctrl_dep : post_ctrl_deps) {
+ *node->add_input() = ctrl_dep;
+ }
+
updated_graph = true;
}
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 534f7a063f..137d51790d 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -99,3 +99,29 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+cc_library(
+ name = "traversal",
+ srcs = ["traversal.cc"],
+ hdrs = ["traversal.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ ],
+)
+
+tf_cc_test(
+ name = "traversal_test",
+ srcs = ["traversal_test.cc"],
+ deps = [
+ ":traversal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/utils/traversal.cc b/tensorflow/core/grappler/utils/traversal.cc
new file mode 100644
index 0000000000..f44f53c4e6
--- /dev/null
+++ b/tensorflow/core/grappler/utils/traversal.cc
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils/traversal.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+void ReverseDfs(const GraphView& graph_view, const std::vector<NodeDef*>& from,
+ const std::function<void(NodeDef*)>& pre_order,
+ const std::function<void(NodeDef*)>& post_order,
+ const std::function<void(NodeDef*, NodeDef*)>& on_back_edge) {
+ // Stack of work to do.
+ struct StackElem {
+ NodeDef* node;
+ bool children_visited;
+ NodeDef* src;
+ };
+ std::vector<StackElem> stack;
+
+ stack.reserve(from.size());
+ for (NodeDef* node : from) {
+ stack.push_back(StackElem{node, false});
+ }
+
+ enum NodeState { NOT_VISITED = 0, VISITING = 1, DONE = 2 };
+ std::unordered_map<NodeDef*, NodeState> node_state;
+ while (!stack.empty()) {
+ StackElem w = stack.back();
+ stack.pop_back();
+
+ if (w.children_visited) {
+ // We've processed all the children of this node
+ node_state[w.node] = DONE;
+ if (post_order) {
+ post_order(w.node);
+ }
+ continue;
+ }
+
+ auto& rslt = node_state[w.node];
+ if (rslt == DONE) {
+ continue;
+ } else if (rslt == VISITING) {
+ // Loop detected
+ if (on_back_edge) {
+ on_back_edge(w.src, w.node);
+ }
+ continue;
+ }
+ rslt = VISITING;
+ if (pre_order) {
+ pre_order(w.node);
+ }
+
+ // Enqueue the node again with the children_visited flag set to true.
+ stack.push_back(StackElem{w.node, true, w.src});
+
+ // Now enqueue the node's children.
+ for (const auto fanin : graph_view.GetFanins(*w.node, true)) {
+ stack.push_back(StackElem{fanin.node, false, w.node});
+ }
+ }
+}
+
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/traversal.h b/tensorflow/core/grappler/utils/traversal.h
new file mode 100644
index 0000000000..bb3fa090e8
--- /dev/null
+++ b/tensorflow/core/grappler/utils/traversal.h
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_
+
+#include <functional>
+#include "tensorflow/core/grappler/graph_view.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Traverse the graph in reverse dfs order, starting from the list of nodes
+// specified in the 'from' argument. The pre_order and post_order functors will
+// be called on each reachable node (including the 'from' nodes) in pre and post
+// order. If loops are found, the on_back_edge functor will be called on the
+// corresponding back edges. The pre- and post-order visits are computed as if
+// these back edges had been cut.
+void ReverseDfs(const GraphView& graph_view, const std::vector<NodeDef*>& from,
+ const std::function<void(NodeDef*)>& pre_order,
+ const std::function<void(NodeDef*)>& post_order,
+ const std::function<void(NodeDef*, NodeDef*)>& on_back_edge);
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_
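For readers skimming the header, here is a short Python sketch of the traversal contract described in the comment above (illustration only; the authoritative implementation is the C++ in traversal.cc): every node moves through NOT_VISITED -> VISITING -> DONE, and an edge that reaches a node still in the VISITING state is reported as a back edge.

```python
NOT_VISITED, VISITING, DONE = 0, 1, 2

def reverse_dfs(fanins, start, pre_order=None, post_order=None,
                on_back_edge=None):
  """fanins: dict mapping each node to the list of its fanin (input) nodes."""
  state = {}
  stack = [(n, False, None) for n in start]     # (node, children_visited, src)
  while stack:
    node, children_visited, src = stack.pop()
    if children_visited:
      state[node] = DONE                        # all fanins processed
      if post_order:
        post_order(node)
      continue
    s = state.get(node, NOT_VISITED)
    if s == DONE:
      continue
    if s == VISITING:                           # loop detected
      if on_back_edge:
        on_back_edge(src, node)
      continue
    state[node] = VISITING
    if pre_order:
      pre_order(node)
    stack.append((node, True, src))             # revisit after the fanins
    for fanin in fanins.get(node, []):
      stack.append((fanin, False, node))
```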
diff --git a/tensorflow/core/grappler/utils/traversal_test.cc b/tensorflow/core/grappler/utils/traversal_test.cc
new file mode 100644
index 0000000000..cc68bd1a96
--- /dev/null
+++ b/tensorflow/core/grappler/utils/traversal_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils/traversal.h"
+//#include "tensorflow/core/framework/node_def.pb.h"
+//#include "tensorflow/core/lib/core/status_test_util.h"
+//#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class TraversalTest : public ::testing::Test {
+ protected:
+ static NodeDef CreateNode(const string& name,
+ const std::vector<string>& inputs) {
+ return CreateNode(name, "", inputs);
+ }
+ static NodeDef CreateNode(const string& name, const string& op,
+ const std::vector<string>& inputs) {
+ NodeDef node;
+ node.set_name(name);
+ if (!op.empty()) {
+ node.set_op(op);
+ }
+ for (const string& input : inputs) {
+ node.add_input(input);
+ }
+ return node;
+ }
+};
+
+TEST_F(TraversalTest, ReverseDfsNoLoop) {
+ GraphDef graph;
+ *graph.add_node() = CreateNode("2", {"5"});
+ *graph.add_node() = CreateNode("0", {"5", "4"});
+ *graph.add_node() = CreateNode("1", {"4", "3"});
+ *graph.add_node() = CreateNode("3", {"2"});
+ *graph.add_node() = CreateNode("5", {});
+ *graph.add_node() = CreateNode("4", {});
+
+ std::vector<NodeDef*> start_nodes = {graph.mutable_node(1),
+ graph.mutable_node(2)};
+ std::vector<string> pre_order;
+ std::vector<string> post_order;
+ bool found_back_edge = false;
+ ReverseDfs(
+ GraphView(&graph), start_nodes,
+ [&pre_order](NodeDef* n) { pre_order.push_back(n->name()); },
+ [&post_order](NodeDef* n) { post_order.push_back(n->name()); },
+ [&found_back_edge](NodeDef*, NodeDef*) { found_back_edge = true; });
+
+ EXPECT_EQ(std::vector<string>({"1", "4", "3", "2", "5", "0"}), pre_order);
+ EXPECT_EQ(std::vector<string>({"4", "5", "2", "3", "1", "0"}), post_order);
+ EXPECT_FALSE(found_back_edge);
+}
+
+TEST_F(TraversalTest, ReverseDfsWithLoop) {
+ GraphDef graph;
+ // Create a loop
+ *graph.add_node() = CreateNode("2", "Merge", {"1", "5"});
+ *graph.add_node() = CreateNode("3", "Switch", {"2"});
+ *graph.add_node() = CreateNode("4", "Identity", {"3"});
+ *graph.add_node() = CreateNode("5", "NextIteration", {"4"});
+ *graph.add_node() = CreateNode("1", "Enter", {});
+ *graph.add_node() = CreateNode("6", "Exit", {"3"});
+
+ std::vector<NodeDef*> start_nodes = {graph.mutable_node(5)};
+ std::vector<string> pre_order;
+ std::vector<string> post_order;
+ std::vector<string> back_edges;
+ ReverseDfs(
+ GraphView(&graph), start_nodes,
+ [&pre_order](NodeDef* n) { pre_order.push_back(n->name()); },
+ [&post_order](NodeDef* n) { post_order.push_back(n->name()); },
+ [&back_edges](NodeDef* src, NodeDef* dst) {
+ back_edges.push_back(strings::StrCat(src->name(), "->", dst->name()));
+ });
+
+ EXPECT_EQ(std::vector<string>({"6", "3", "2", "1", "5", "4"}), pre_order);
+ EXPECT_EQ(std::vector<string>({"1", "4", "5", "2", "3", "6"}), post_order);
+ EXPECT_EQ(std::vector<string>({"4->3"}), back_edges);
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_util.cc b/tensorflow/core/kernels/batch_util.cc
index 7f2df95e2d..87d455faa7 100644
--- a/tensorflow/core/kernels/batch_util.cc
+++ b/tensorflow/core/kernels/batch_util.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
+#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
+
namespace tensorflow {
namespace batch_util {
@@ -61,6 +63,21 @@ Status HandleElementToSlice<string>(Tensor element, Tensor* parent, int64 index,
return Status::OK();
}
+template <>
+Status HandleElementToSlice<Variant>(Tensor element, Tensor* parent,
+ int64 index, bool can_move) {
+ auto parent_as_matrix = parent->flat_outer_dims<Variant>();
+ auto element_flat = element.flat<Variant>();
+ if (can_move) {
+ for (int64 i = 0; i < element.NumElements(); ++i) {
+ parent_as_matrix(index, i) = std::move(element_flat(i));
+ }
+ } else {
+ parent_as_matrix.chip(index, 0) = element_flat;
+ }
+ return Status::OK();
+}
+
// TODO(jsimsa): Add HandleElementToSlice<variant> specialization that moves
// the data when possible.
@@ -115,5 +132,101 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
}
}
+// The following five functions are copied from padding_fifo_queue.cc.
+// TODO(mrry): Reconcile these functions with the similar methods in the
+// queue implementation.
+Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
+ DCHECK_NE(parent->dim_size(0), 0);
+ if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
+ TensorShape chip_shape = parent->shape();
+ chip_shape.RemoveDim(0);
+ return errors::Internal(
+ "HandleElementToLargerSlice Cannot copy slice: number of entries in "
+ "element is greater than number of elements in parent slice. ",
+ "Shapes are: [element]: ", element.shape().DebugString(),
+ ", [parent slice]: ", chip_shape.DebugString());
+ }
+ return Status::OK();
+}
+
+template <typename T, int NDIMS>
+Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
+ int index) {
+ TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
+ if (element.NumElements() == 0) {
+ return Status::OK();
+ }
+ auto element_t = element.tensor<T, NDIMS>();
+ auto parent_t = parent->tensor<T, NDIMS + 1>();
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
+ slice_indices[0] = index;
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
+ slice_size[0] = 1;
+ for (size_t i = 1; i < slice_size.size(); ++i) {
+ slice_size[i] = element_t.dimension(i - 1);
+ }
+ parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
+ return Status::OK();
+}
+
+template <int NDIMS>
+Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
+ int index) {
+#define HANDLE_TYPE(T) \
+ case DataTypeToEnum<T>::value: { \
+ return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
+ }
+
+ switch (element.dtype()) {
+ TF_CALL_DATASET_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ default:
+ return errors::Unimplemented(
+ "HandleElementToLargerSliceWithRank Unhandled data type: ",
+ element.dtype());
+ }
+}
+
+Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
+ int index) {
+ if (parent->dims() != element.dims() + 1) {
+ return errors::Internal(
+ "Mismatched ranks. Element's rank is: ", element.dims(),
+ " but element is meant to be a slice in output Tensor having rank: ",
+ parent->dims(), " (should be: ", element.dims() + 1, ")");
+ }
+
+#define HANDLE_DIMS(NDIMS) \
+ case NDIMS: { \
+ TF_RETURN_IF_ERROR( \
+ HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
+ return Status::OK(); \
+ }
+
+ switch (element.dims()) {
+ HANDLE_DIMS(0);
+ HANDLE_DIMS(1);
+ HANDLE_DIMS(2);
+ HANDLE_DIMS(3);
+ HANDLE_DIMS(4);
+#undef HANDLE_DIMS
+ default:
+ return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
+ element.dims());
+ }
+}
+
+Status SetElementZero(Tensor* element, const Tensor& padding) {
+#define HANDLE_TYPE(T) \
+ if (element->dtype() == DataTypeToEnum<T>::value) { \
+ element->flat<T>().setConstant(padding.scalar<T>()()); \
+ return Status::OK(); \
+ }
+ TF_CALL_DATASET_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ return errors::Unimplemented("SetElementZero Unhandled data type: ",
+ element->dtype());
+}
+
} // namespace batch_util
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_util.h b/tensorflow/core/kernels/batch_util.h
index 0d634ae7b0..a47bf1935d 100644
--- a/tensorflow/core/kernels/batch_util.h
+++ b/tensorflow/core/kernels/batch_util.h
@@ -32,6 +32,16 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
// Copies the index^th slice of parent (in the 0th dimension) into element.
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index);
+// Zero-initializes the tensor `element` using the scalar stored in `padding`.
+// Both `element` and `padding` must have matching `dtype`.
+Status SetElementZero(Tensor* element, const Tensor& padding);
+
+// Copies `element` into a (0th dimension) slice of `parent`, assuming
+// the shape of `element` is no larger along any axis than a slice of
+// `parent`.
+Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
+ int index);
+
} // namespace batch_util
} // namespace tensorflow
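A NumPy analogy of the two helpers declared above (a sketch to show the intended padding semantics only; the real implementations operate on `Tensor` objects and live in batch_util.cc):

```python
import numpy as np

# SetElementZero: fill one batch slot with the padding value.
parent = np.empty((2, 3, 4), dtype=np.int32)   # batch of 2 padded elements
padding = np.int32(-1)
parent[0].fill(padding)

# CopyElementToLargerSlice: copy a smaller element into the leading corner of
# that slot; all remaining entries keep the padding value.
element = np.array([[1, 2], [3, 4]], dtype=np.int32)
parent[0][tuple(slice(0, d) for d in element.shape)] = element
```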
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index cdb4023861..c4e21257ff 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -121,6 +121,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:batch_util",
],
)
@@ -402,6 +403,19 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "tensor_queue_dataset_op",
+ srcs = ["tensor_queue_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:batch_util",
+ ],
+)
+
+tf_kernel_library(
name = "tensor_slice_dataset_op",
srcs = ["tensor_slice_dataset_op.cc"],
deps = [
@@ -539,6 +553,7 @@ tf_kernel_library(
":stats_dataset_ops",
":take_dataset_op",
":tensor_dataset_op",
+ ":tensor_queue_dataset_op",
":tensor_slice_dataset_op",
":unique_dataset_op",
":zip_dataset_op",
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index b37bd672ad..dd5f4a4554 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -516,15 +517,32 @@ class IteratorHandleOp : public OpKernel {
return Status::OK();
}
+ template <typename To, typename From> // use like this: down_cast<T*>(foo);
+ static inline To down_cast(From* f) { // so we only accept pointers
+ static_assert(
+ (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
+ "target type not derived from source type");
+
+ // We skip the assert and hence the dynamic_cast if RTTI is disabled.
+#if !defined(__GNUC__) || defined(__GXX_RTTI)
+ // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
+ assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
+#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
+ return static_cast<To>(f);
+ }
+
FunctionLibraryRuntime* CreatePrivateFLR(
OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
- Device* device = new ThreadPoolDevice(
- SessionOptions(), ctx->device()->attributes().name(), Bytes(256 << 20),
- DeviceLocality(), cpu_allocator());
-
- device_mgr->reset(new DeviceMgr({device}));
+ // Wrap the existing device in order to see any captured resources
+ // in its resource manager. The existing device will outlive the
+ // IteratorResource, because we are storing the IteratorResource
+ // in that device's resource manager.
+ Device* wrapped_device = RenamedDevice::NewRenamedDevice(
+ ctx->device()->name(), down_cast<Device*>(ctx->device()),
+ false /* owns_underlying */, false /* isolate_session_state */);
+ device_mgr->reset(new DeviceMgr({wrapped_device}));
flib_def->reset(new FunctionLibraryDefinition(
*ctx->function_library()->GetFunctionLibraryDefinition()));
pflr->reset(new ProcessFunctionLibraryRuntime(
@@ -532,7 +550,7 @@ class IteratorHandleOp : public OpKernel {
{} /* TODO(mrry): OptimizerOptions? */,
nullptr /* TODO(mrry): ClusterFLR */));
- return (*pflr)->GetFLR(device->name());
+ return (*pflr)->GetFLR(ctx->device()->name());
}
mutex mu_;
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index 4fe4e8e294..cfb4efda9a 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/kernels/batch_util.h"
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
@@ -24,102 +25,6 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-// The following five functions are copied from padding_fifo_queue.cc.
-// TODO(mrry): Reconcile these functions with the similar methods in the
-// queue implementation.
-Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
- DCHECK_NE(parent->dim_size(0), 0);
- if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
- TensorShape chip_shape = parent->shape();
- chip_shape.RemoveDim(0);
- return errors::Internal(
- "HandleElementToLargerSlice Cannot copy slice: number of entries in "
- "element is greater than number of elements in parent slice. ",
- "Shapes are: [element]: ", element.shape().DebugString(),
- ", [parent slice]: ", chip_shape.DebugString());
- }
- return Status::OK();
-}
-
-template <typename T, int NDIMS>
-Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
- int index) {
- TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
- if (element.NumElements() == 0) {
- return Status::OK();
- }
- auto element_t = element.tensor<T, NDIMS>();
- auto parent_t = parent->tensor<T, NDIMS + 1>();
- Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
- slice_indices[0] = index;
- Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
- slice_size[0] = 1;
- for (size_t i = 1; i < slice_size.size(); ++i) {
- slice_size[i] = element_t.dimension(i - 1);
- }
- parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
- return Status::OK();
-}
-
-template <int NDIMS>
-Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
- int index) {
-#define HANDLE_TYPE(T) \
- case DataTypeToEnum<T>::value: { \
- return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
- }
-
- switch (element.dtype()) {
- TF_CALL_DATASET_TYPES(HANDLE_TYPE);
-#undef HANDLE_TYPE
- default:
- return errors::Unimplemented(
- "HandleElementToLargerSliceWithRank Unhandled data type: ",
- element.dtype());
- }
-}
-
-Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
- int index) {
- if (parent->dims() != element.dims() + 1) {
- return errors::Internal(
- "Mismatched ranks. Element's rank is: ", element.dims(),
- " but element is meant to be a slice in output Tensor having rank: ",
- parent->dims(), " (should be: ", element.dims() + 1, ")");
- }
-
-#define HANDLE_DIMS(NDIMS) \
- case NDIMS: { \
- TF_RETURN_IF_ERROR( \
- HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
- return Status::OK(); \
- }
-
- switch (element.dims()) {
- HANDLE_DIMS(0);
- HANDLE_DIMS(1);
- HANDLE_DIMS(2);
- HANDLE_DIMS(3);
- HANDLE_DIMS(4);
-#undef HANDLE_DIMS
- default:
- return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
- element.dims());
- }
-}
-
-Status SetElementZero(Tensor* element, const Tensor& padding) {
-#define HANDLE_TYPE(T) \
- if (element->dtype() == DataTypeToEnum<T>::value) { \
- element->flat<T>().setConstant(padding.scalar<T>()()); \
- return Status::OK(); \
- }
- TF_CALL_DATASET_TYPES(HANDLE_TYPE);
-#undef HANDLE_TYPE
- return errors::Unimplemented("SetElementZero Unhandled data type: ",
- element->dtype());
-}
-
class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx)
@@ -379,17 +284,24 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
Tensor batch_component(ctx->allocator({}),
output_dtypes()[component_index],
batch_component_shape);
- TF_RETURN_IF_ERROR(SetElementZero(
+ TF_RETURN_IF_ERROR(batch_util::SetElementZero(
&batch_component, dataset()->padding_values_[component_index]));
// Build the output tuple component by copying one slice
// from each input element in the batch.
+ TensorShape component_shape({});
+ for (int i = 1; i < batch_component_shape.dims(); ++i) {
+ component_shape.AddDim(batch_component_shape.dim_size(i));
+ }
for (int64 i = 0; i < num_batch_elements; ++i) {
- TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(
- batch_elements[i][component_index], &batch_component));
-
- TF_RETURN_IF_ERROR(CopyElementToLargerSlice(
- batch_elements[i][component_index], &batch_component, i));
+ // Take the fast path if possible.
+ if (batch_elements[i][component_index].shape() == component_shape) {
+ TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
+ batch_elements[i][component_index], &batch_component, i));
+ } else {
+ TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice(
+ batch_elements[i][component_index], &batch_component, i));
+ }
}
out_tensors->push_back(std::move(batch_component));
}
diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
new file mode 100644
index 0000000000..ff412a4671
--- /dev/null
+++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
@@ -0,0 +1,646 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <deque>
+
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/kernels/batch_util.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+
+namespace tensorflow {
+
+namespace {
+
+bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a,
+ const PartialTensorShape& b) {
+ // Returns true if dims[a] >= dims[b], or are compatible.
+ if (a.unknown_rank()) return true;
+ if (a.dims() != b.dims()) return false;
+ for (int d = 0; d < a.dims(); ++d) {
+ if (a.dim_size(d) == -1 || b.dim_size(d) == -1) continue;
+ if (a.dim_size(d) < b.dim_size(d)) return false;
+ }
+ return true;
+}
+
+DataTypeVector PrependQueueType(const DataTypeVector& dtypes) {
+ DataTypeVector out;
+ out.reserve(dtypes.size() + 1);
+ out.push_back(DT_VARIANT); // The queue component.
+ for (const DataType& d : dtypes) out.push_back(d);
+ return out;
+}
+
+std::vector<PartialTensorShape> PrependQueueShapeWithBatch(
+ const std::vector<PartialTensorShape>& shapes) {
+ std::vector<PartialTensorShape> out;
+ out.reserve(shapes.size() + 1);
+ out.emplace_back(PartialTensorShape({-1})); // The queue component.
+ for (PartialTensorShape s : shapes) {
+ s.InsertDim(0, -1); // Unknown batch size.
+ out.push_back(std::move(s));
+ }
+ return out;
+}
+
+class EnqueueInQueueDatasetOp;
+
+class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
+ public:
+ PrependFromQueueAndPaddedBatchDataset(
+ OpKernelContext* ctx, const int64 batch_size, const DatasetBase* input,
+ const DataTypeVector& dtypes,
+ const std::vector<PartialTensorShape>& shapes,
+ std::vector<Tensor> padding_values)
+ : GraphDatasetBase(ctx),
+ batch_size_(batch_size),
+ input_(input),
+ dtypes_(dtypes),
+ shapes_(shapes),
+ padding_values_(std::move(padding_values)),
+ dtypes_with_queue_(PrependQueueType(dtypes)),
+ batched_shapes_with_queue_(PrependQueueShapeWithBatch(shapes)) {
+ input_->Ref();
+ }
+
+ ~PrependFromQueueAndPaddedBatchDataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::PrependFromQueueAndPaddedBatch")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return dtypes_with_queue_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return batched_shapes_with_queue_;
+ }
+
+ string DebugString() override {
+ return "PrependFromQueueAndPaddedBatchDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph));
+ Node* batch_size = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
+
+ std::vector<Node*> padded_shapes;
+ padded_shapes.reserve(shapes_.size());
+ for (int i = 0; i < shapes_.size(); i++) {
+ Node* node;
+ Tensor t(DT_INT64, TensorShape({shapes_[i].dims()}));
+ for (int j = 0; j < shapes_[i].dims(); j++) {
+ t.vec<int64>()(j) = shapes_[i].dim_size(j);
+ }
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ padded_shapes.emplace_back(node);
+ }
+
+ std::vector<Node*> padding_values;
+ padding_values.reserve(padding_values_.size());
+ for (const Tensor& t : padding_values_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ padding_values.emplace_back(node);
+ }
+
+ AttrValue output_types;
+ b->BuildAttrValue(dtypes_, &output_types);
+
+ AttrValue output_shapes;
+ b->BuildAttrValue(batched_shapes_with_queue_, &output_shapes);
+
+ AttrValue N;
+ b->BuildAttrValue<int64>(shapes_.size(), &N);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, input_graph}, {1, batch_size}},
+ {{2, padded_shapes}, {3, padding_values}},
+ {{"Toutput_types", output_types},
+ {"output_shapes", output_shapes},
+ {"N", N}},
+ output));
+
+ return Status::OK();
+ }
+
+ private:
+ friend class EnqueueInQueueDatasetOp;
+
+ class Iterator
+ : public DatasetIterator<PrependFromQueueAndPaddedBatchDataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<PrependFromQueueAndPaddedBatchDataset>(params),
+ queue_(new TensorQueue(/*input_impl*/
+ params.dataset->input_->MakeIterator(
+ params.prefix),
+ params.dataset->dtypes_,
+ params.dataset->shapes_)) {}
+
+ ~Iterator() override { queue_->Unref(); }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ std::vector<std::vector<Tensor>> batch;
+ TF_RETURN_IF_ERROR(queue_->GetNext(ctx, dataset()->batch_size_, &batch,
+ end_of_sequence));
+ const auto& dtypes = dataset()->dtypes_;
+ const auto& shapes = dataset()->shapes_;
+ const auto& input_shapes = dataset()->input_->output_shapes();
+ const auto& padding_values = dataset()->padding_values_;
+ const int64 batch_size = batch.size();
+ out_tensors->reserve(dtypes.size());
+
+ std::vector<TensorShape> max_shapes; // Of non-queue components.
+ for (int i = 0; i < dtypes.size(); ++i) {
+ const PartialTensorShape& shape = shapes[i];
+ TensorShape out_shape({batch_size});
+ for (int r = 0; r < shape.dims(); ++r) {
+ if (shape.dim_size(r) >= 0) {
+ // padded_shape[r] is known.
+ out_shape.AddDim(shape.dim_size(r));
+ } else {
+ // padded_shape[r] is unknown, find the maximum across
+ // the batch.
+ int64 dim = 0;
+ for (int b = 0; b < batch.size(); ++b) {
+ dim = std::max(dim, batch[b][i].dim_size(r));
+ }
+ out_shape.AddDim(dim);
+ }
+ }
+ max_shapes.push_back(std::move(out_shape));
+ }
+
+ Tensor queues_t(cpu_allocator(), DT_VARIANT, TensorShape({batch_size}));
+ if (!batch.empty()) {
+ auto queues = queues_t.flat<Variant>();
+ Variant& queue_inserter = queues(0);
+ queue_inserter = TensorQueueInserter();
+ queue_inserter.get<TensorQueueInserter>()->set_queue(queue_);
+ for (int b = 1; b < batch.size(); ++b) {
+ // Copy the TensorQueueInserter. Each copy increments the
+ // Ref on the queue_.
+ queues(b) = queues(0);
+ }
+ }
+ out_tensors->push_back(std::move(queues_t));
+
+ for (int i = 0; i < max_shapes.size(); ++i) {
+ Tensor component(cpu_allocator(), dtypes[i], max_shapes[i]);
+ // Try hard to take the fast path.
+ if (shapes[i].IsFullyDefined() &&
+ shapes[i].IsIdenticalTo(input_shapes[i])) {
+ // Take the fast path if we know all the shapes statically.
+ for (int64 b = 0; b < batch.size(); ++b) {
+ TF_RETURN_IF_ERROR(
+ batch_util::CopyElementToSlice(batch[b][i], &component, b));
+ }
+ } else {
+ TF_RETURN_IF_ERROR(
+ batch_util::SetElementZero(&component, padding_values[i]));
+ for (int64 b = 0; b < batch.size(); ++b) {
+ if (batch[b][i].shape() == max_shapes[i]) {
+ TF_RETURN_IF_ERROR(
+ batch_util::CopyElementToSlice(batch[b][i], &component, b));
+ } else {
+ TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice(
+ batch[b][i], &component, b));
+ }
+ }
+ }
+ out_tensors->push_back(std::move(component));
+ }
+
+ // end_of_sequence was set before we populated out_tensors, so
+ // it's ok to return now.
+ return Status::OK();
+ }
+
+ protected:
+ // Work around bug in MSVC that disallows access to protected
+ // members of Iterator from within TensorQueue.
+ class TensorQueue;
+ friend class TensorQueue;
+
+ class TensorQueue : public core::RefCounted {
+ public:
+ TensorQueue(std::unique_ptr<IteratorBase> input_impl,
+ const DataTypeVector& dtypes,
+ const std::vector<PartialTensorShape>& shapes)
+ : dtypes_(dtypes),
+ shapes_(shapes),
+ input_impl_(std::move(input_impl)) {}
+
+ void MaybeWaitForNotificationLocked(mutex_lock* lock)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // This essentially just releases the lock and immediately relocks.
+ cv_.wait_for(*lock, std::chrono::milliseconds(0));
+ }
+
+ void NotifyLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { cv_.notify_all(); }
+
+ Status GetNext(IteratorContext* ctx, const int64 batch_size,
+ std::vector<std::vector<Tensor>>* batch,
+ bool* end_of_sequence) {
+ mutex_lock lock(mu_);
+
+ *end_of_sequence = false;
+
+ for (int64 b = 0; b < batch_size;) {
+ if (!entries_.empty()) {
+ batch->push_back(std::move(entries_.front()));
+ entries_.pop_front();
+ ++b;
+ continue;
+ } else {
+ if (input_impl_) {
+ // There's still input coming in.
+ std::vector<Tensor> tensors;
+ bool input_end;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &tensors, &input_end));
+ if (!input_end) {
+ batch->push_back(std::move(tensors));
+ ++b;
+ continue;
+ } else {
+ input_impl_.reset();
+ }
+ }
+ if (!input_impl_) {
+ // There's no more input coming in.
+ if (RefCountIsOne()) {
+ // No TensorQueueInserters in the wild.
+ if (batch->empty()) {
+ *end_of_sequence = true;
+ }
+ break;
+ } else {
+ MaybeWaitForNotificationLocked(&lock);
+ // If there's data available, try to add entries again.
+ // Otherwise return a smaller batch and hope the next
+ // iterator request has a non-empty or unused queue_.
+ if (entries_.empty()) {
+ break;
+ }
+ }
+ }
+ }
+ } // for (int64 b = ... batch_size)
+ return Status::OK();
+ }
+
+ Status Insert(const std::vector<Tensor>& tensors) {
+ if (tensors.size() != dtypes_.size()) {
+ return errors::InvalidArgument(
+ "TensorQueue::Insert: mismatched number of tensors. Queue "
+ "expects ",
+ dtypes_.size(), " tensors but tried to insert ", tensors.size());
+ }
+ for (int i = 0; i < tensors.size(); ++i) {
+ if (tensors[i].dtype() != dtypes_[i]) {
+ return errors::InvalidArgument(
+ "TensorQueue::Insert: mismatched dtypes at component ", i,
+ ". Attempted "
+ "to insert tensor of type ",
+ DataTypeString(tensors[i].dtype()),
+ " but queue expected type: ", DataTypeString(dtypes_[i]));
+ }
+ if (!shapes_[i].IsCompatibleWith(tensors[i].shape())) {
+ return errors::InvalidArgument(
+ "TensorQueue::Insert: mismatched shapes at component ", i,
+ ". Attempted "
+ "to insert tensor with shape ",
+ tensors[i].shape().DebugString(),
+ " but queue expected shape: ", shapes_[i].DebugString());
+ }
+ }
+ mutex_lock lock(mu_);
+ entries_.push_back(tensors);
+ NotifyLocked();
+ return Status::OK();
+ }
+
+ Status Save(Iterator* iter, IteratorStateWriter* writer) {
+ mutex_lock lock(mu_);
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(iter->SaveParent(writer, input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(iter->full_name("input_exhausted"), ""));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(iter->full_name("entries_size"),
+ entries_.size()));
+ for (int64 b = 0; b < entries_.size(); ++b) {
+ for (int i = 0; i < dtypes_.size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteTensor(strings::StrCat(iter->full_name("entries"),
+ "[", b, "][", i, "]"),
+ entries_[b][i]));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Restore(Iterator* iter, IteratorContext* ctx,
+ IteratorStateReader* reader) {
+ mutex_lock l(mu_);
+ if (reader->Contains(iter->full_name("input_exhausted"))) {
+ input_impl_.reset();
+ } else {
+ input_impl_ = iter->dataset_input()->MakeIterator(iter->prefix());
+ TF_RETURN_IF_ERROR(iter->RestoreParent(ctx, reader, input_impl_));
+ }
+ entries_.clear();
+ int64 entries_size = -1;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(iter->full_name("entries_size"), &entries_size));
+ if (entries_size < 0) {
+ return errors::DataLoss(
+ "Expected entries_size key '", iter->full_name("entries_size"),
+ "' to have nonnegative value, but saw: ", entries_size);
+ }
+ for (int64 b = 0; b < entries_size; ++b) {
+ std::vector<Tensor> entry;
+ for (int i = 0; i < dtypes_.size(); ++i) {
+ Tensor value;
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(strings::StrCat(iter->full_name("entries"),
+ "[", b, "][", i, "]"),
+ &value));
+ entry.push_back(std::move(value));
+ }
+ entries_.push_back(std::move(entry));
+ }
+ return Status::OK();
+ }
+
+ mutex* mu() { return &mu_; }
+
+ private:
+ DataTypeVector dtypes_;
+ std::vector<PartialTensorShape> shapes_;
+
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::deque<std::vector<Tensor>> entries_ GUARDED_BY(mu_);
+ condition_variable cv_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* dataset_input() const { return dataset()->input_; }
+
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return queue_->Save(this, writer);
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return queue_->Restore(this, ctx, reader);
+ }
+
+ public:
+ class TensorQueueInserter {
+ public:
+ TensorQueueInserter() : queue_(nullptr) {}
+
+ void set_queue(TensorQueue* queue) {
+ queue_ = queue;
+ queue_->Ref();
+ }
+
+ TensorQueueInserter(const TensorQueueInserter& rhs) {
+ queue_ = rhs.queue_;
+ queue_->Ref();
+ };
+
+ TensorQueueInserter(TensorQueueInserter&& rhs) {
+ queue_ = rhs.queue_;
+ rhs.queue_ = nullptr;
+ }
+
+ TensorQueueInserter& operator=(const TensorQueueInserter& rhs) = delete;
+
+ string TypeName() const { return "tensorflow::TensorQueueInserter"; }
+ string DebugString() const { return TypeName(); }
+
+ void Encode(VariantTensorData*) const {}
+ bool Decode(const VariantTensorData&) { return false; }
+
+ ~TensorQueueInserter() {
+ if (queue_) {
+ mutex_lock lock(*queue_->mu());
+ queue_->Unref();
+ queue_->NotifyLocked();
+ queue_ = nullptr;
+ }
+ }
+
+ Status Insert(const std::vector<Tensor>& tensors) const {
+ CHECK(queue_);
+ return queue_->Insert(tensors);
+ }
+
+ private:
+ mutable TensorQueue* queue_;
+ };
+
+ private:
+ TensorQueue* const queue_;
+ };
+
+ private:
+ const int64 batch_size_;
+ const DatasetBase* input_;
+ const DataTypeVector dtypes_;
+ const std::vector<PartialTensorShape> shapes_;
+ const std::vector<Tensor> padding_values_;
+ const DataTypeVector dtypes_with_queue_;
+ const std::vector<PartialTensorShape> batched_shapes_with_queue_;
+};
+
+class PrependFromQueueAndPaddedBatchDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit PrependFromQueueAndPaddedBatchDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput_types", &output_types_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 batch_size = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
+ OP_REQUIRES(
+ ctx, batch_size > 0,
+ errors::InvalidArgument("Batch size must be greater than zero."));
+
+ OpInputList padded_shape_tensors;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("padded_shapes", &padded_shape_tensors));
+ std::vector<PartialTensorShape> padded_shapes;
+ padded_shapes.reserve(padded_shape_tensors.size());
+ OP_REQUIRES(ctx,
+ padded_shape_tensors.size() == input->output_shapes().size(),
+ errors::InvalidArgument("Number of padded shapes (",
+ padded_shape_tensors.size(),
+ ") must match the number of components "
+ "in the input dataset's elements (",
+ input->output_shapes().size(), ")"));
+ for (const Tensor& padded_shape_t : padded_shape_tensors) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()),
+ errors::InvalidArgument("All padded shapes must be vectors"));
+ PartialTensorShape padded_shape;
+ OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
+ padded_shape_t.vec<int64>().data(),
+ padded_shape_t.NumElements(), &padded_shape));
+ padded_shapes.push_back(std::move(padded_shape));
+ }
+
+ OP_REQUIRES(
+ ctx, input->output_dtypes() == output_types_,
+ errors::InvalidArgument("Input dataset and this dataset "
+ "have different output_types: ",
+ DataTypeVectorString(input->output_dtypes()),
+ " and ", DataTypeVectorString(output_types_)));
+
+ for (int i = 0; i < input->output_shapes().size(); ++i) {
+ // Exclude the queue from the tensor_shapes calculation.
+ const PartialTensorShape& tensor_shape = padded_shapes[i];
+ OP_REQUIRES(
+ ctx,
+ IsGreaterEqualToOrCompatibleWith(tensor_shape,
+ input->output_shapes()[i]),
+ errors::InvalidArgument("Incompatible input shapes at component ", i,
+ " between input dataset and this dataset: ",
+ input->output_shapes()[i].DebugString(),
+ " vs. ", tensor_shape.DebugString()));
+ }
+
+ OpInputList padding_values_list;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("padding_values", &padding_values_list));
+ std::vector<Tensor> padding_values;
+ OP_REQUIRES(ctx,
+ padding_values_list.size() == input->output_shapes().size(),
+ errors::InvalidArgument(
+ "Number of padding values (", padding_values_list.size(),
+ ") must match the number of components in the input "
+ "dataset's elements (",
+ input->output_shapes().size(), ")"));
+ for (int i = 0; i < padding_values_list.size(); ++i) {
+ const Tensor& padding_value_t = padding_values_list[i];
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()),
+ errors::InvalidArgument(
+ "All padding values must be scalars; but at component ", i,
+ " saw shape: ", padding_value_t.shape().DebugString()));
+ OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i],
+ errors::InvalidArgument(
+ "Mismatched type between padding value ", i,
+ " and input dataset's component ", i, ": ",
+ DataTypeString(padding_value_t.dtype()), " vs. ",
+ DataTypeString(input->output_dtypes()[i])));
+ padding_values.push_back(padding_value_t);
+ }
+
+ *output = new PrependFromQueueAndPaddedBatchDataset(
+ ctx, batch_size, input, output_types_, padded_shapes,
+ std::move(padding_values));
+ }
+
+ private:
+ DataTypeVector output_types_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("PrependFromQueueAndPaddedBatchDataset").Device(DEVICE_CPU),
+ PrependFromQueueAndPaddedBatchDatasetOp);
+
+class EnqueueInQueueDatasetOp : public OpKernel {
+ public:
+ explicit EnqueueInQueueDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) override {
+ using TensorQueueInserter =
+ PrependFromQueueAndPaddedBatchDataset::Iterator::TensorQueueInserter;
+
+ // TODO(ebrevdo): accept list of sequence lengths to do proper
+ // sub-slicing of tensors for placement into the queue?
+ const Tensor& tensor_queue_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_queue_t.shape()),
+ errors::InvalidArgument("queue must be a vector, saw shape: ",
+ tensor_queue_t.shape().DebugString()));
+ std::vector<const TensorQueueInserter*> inserters;
+ const int64 batch_size = tensor_queue_t.NumElements();
+ inserters.reserve(batch_size);
+ const Variant* variants = tensor_queue_t.flat<Variant>().data();
+ for (int i = 0; i < batch_size; ++i) {
+ const auto* inserter = variants[i].get<TensorQueueInserter>();
+ OP_REQUIRES(ctx, inserter != nullptr,
+ errors::InvalidArgument(
+ "Could not access TensorQueueInserter from queue[", i,
+ "]. Received variant: ", variants[i].DebugString()));
+ inserters.push_back(inserter);
+ }
+
+ OpInputList components;
+ OP_REQUIRES_OK(ctx, ctx->input_list("components", &components));
+ for (int i = 0; i < components.size(); ++i) {
+ OP_REQUIRES(
+ ctx,
+ components[i].dims() > 0 && components[i].dim_size(0) == batch_size,
+ errors::InvalidArgument(
+ "Expected component ", i, " to have batched shape [", batch_size,
+ ",...], but saw shape: ", components[i].shape().DebugString()));
+ }
+ std::vector<TensorShape> element_shapes;
+ for (int i = 0; i < components.size(); ++i) {
+ TensorShape element_shape = components[i].shape();
+ element_shape.RemoveDim(0);
+ element_shapes.push_back(std::move(element_shape));
+ }
+ for (int64 b = 0; b < batch_size; ++b) {
+ std::vector<Tensor> tensors;
+ tensors.reserve(components.size());
+ for (int i = 0; i < components.size(); ++i) {
+ Tensor t(components[i].dtype(), element_shapes[i]);
+ OP_REQUIRES_OK(ctx,
+ batch_util::CopySliceToElement(components[i], &t, b));
+ tensors.push_back(std::move(t));
+ }
+ // TODO(ebrevdo): Acquire the lock once for all inserters with
+ // the same underlying queue? Add InsertLocked?
+ OP_REQUIRES_OK(ctx, inserters[b]->Insert(tensors));
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU),
+ EnqueueInQueueDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index d6cbcf1d93..0a38d3d4af 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/platform/mem.h"
@@ -141,6 +143,7 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_quint16(REGISTER_GATHER_CPU);
TF_CALL_qint16(REGISTER_GATHER_CPU);
+TF_CALL_variant(REGISTER_GATHER_CPU);
#undef REGISTER_GATHER_CPU
diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc
index d7fff4bb0c..1439141f64 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op.cc
@@ -62,7 +62,15 @@ class MatrixBandPartOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
errors::InvalidArgument("num_lower must be scalar, got shape ",
num_lower_in.shape().DebugString()));
- const int64 num_lower = num_lower_in.scalar<int64>()();
+
+ auto as_int64_scalar = [](const Tensor& tensor) -> int64 {
+ if (tensor.dtype() == DT_INT32) {
+ return tensor.scalar<int32>()();
+ } else {
+ return tensor.scalar<int64>()();
+ }
+ };
+ const int64 num_lower = as_int64_scalar(num_lower_in);
OP_REQUIRES(
context, num_lower <= input_reshaped.dimension(1),
errors::InvalidArgument(
@@ -73,7 +81,7 @@ class MatrixBandPartOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in.shape().DebugString()));
- const int64 num_upper = num_upper_in.scalar<int64>()();
+ const int64 num_upper = as_int64_scalar(num_upper_in);
OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
errors::InvalidArgument("num_upper must be negative or less or "
"equal to number of columns (",
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 8f7f91c9df..e0b85c6d06 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -294,6 +294,11 @@ class StridedSliceAssignOp : public OpKernel {
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
old_lhs = *v->tensor();
+ OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
+ errors::InvalidArgument(
+ "l-value dtype ", DataTypeString(old_lhs.dtype()),
+ " does not match r-value dtype ",
+ DataTypeString(DataTypeToEnum<T>::value)));
} else {
context->forward_ref_input_to_ref_output(0, 0);
old_lhs = context->mutable_input(0, true);
@@ -386,6 +391,7 @@ class StridedSliceAssignOp : public OpKernel {
StridedSliceAssignOp<CPUDevice, type>)
TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
+TF_CALL_variant(REGISTER_STRIDED_SLICE);
#undef REGISTER_STRIDED_SLICE
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index ac1259a9ac..c3187e49ce 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
@@ -288,6 +290,7 @@ DECLARE_FOR_N_GPU(int64);
#endif // END GOOGLE_CUDA
TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
+TF_CALL_variant(DECLARE_FOR_N_CPU);
#ifdef TENSORFLOW_USE_SYCL
#define PREVENT_FOR_N_SYCL(T) \
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 5ec2a4e9b4..267ce88440 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -708,10 +708,11 @@ REGISTER_OP("MatrixDiagPart")
// --------------------------------------------------------------------------
REGISTER_OP("MatrixBandPart")
.Input("input: T")
- .Input("num_lower: int64")
- .Input("num_upper: int64")
+ .Input("num_lower: Tindex")
+ .Input("num_upper: Tindex")
.Output("band: T")
.Attr("T: type")
+ .Attr("Tindex: {int32, int64} = DT_INT64")
.SetShapeFn(shape_inference::UnchangedShape);
// --------------------------------------------------------------------------
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 65ab81931a..177561161e 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -17137,6 +17137,24 @@ op {
}
}
op {
+ name: "EnqueueInQueueDataset"
+ input_arg {
+ name: "queue"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "components"
+ type_list_attr: "Tcomponents"
+ }
+ attr {
+ name: "Tcomponents"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "Enter"
input_arg {
name: "data"
@@ -24841,6 +24859,42 @@ op {
}
}
op {
+ name: "MatrixBandPart"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "num_lower"
+ type_attr: "Tindex"
+ }
+ input_arg {
+ name: "num_upper"
+ type_attr: "Tindex"
+ }
+ output_arg {
+ name: "band"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tindex"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MatrixDeterminant"
input_arg {
name: "input"
@@ -32097,6 +32151,48 @@ op {
}
}
op {
+ name: "PrependFromQueueAndPaddedBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "padded_shapes"
+ type: DT_INT64
+ number_attr: "N"
+ }
+ input_arg {
+ name: "padding_values"
+ type_list_attr: "Toutput_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "Toutput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "PreventGradient"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 2cae814eab..3c8e9a8a5f 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -491,4 +491,29 @@ REGISTER_OP("StatsAggregatorSummary")
.Output("summary: string")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("PrependFromQueueAndPaddedBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("padded_shapes: N * int64")
+ .Input("padding_values: Toutput_types")
+ .Output("handle: variant")
+ .Attr("Toutput_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("N: int >= 1")
+ // TODO(ebrevdo): Validate that `padded_shapes` are all vectors, the lengths
+ // of `Toutput_types` and `output_shapes` are `N`, that the
+ // length of `output_types` is `N`, the `output_shapes` are
+ // (as far as possible to tell statically) compatible with `padded_shapes`,
+ // and that `padding_values` are all scalars.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("EnqueueInQueueDataset")
+ .Input("queue: variant")
+ .Input("components: Tcomponents")
+ .Attr("Tcomponents: list(type) >= 1")
+ .SetIsStateful() // To avoid CSE on multiple calls to Enqueue.
+ // TODO(ebrevdo): SetShapeFn to test input dtypes and shapes by
+ // reading from queue handle (is that even possible?).
+ .SetShapeFn(shape_inference::NoOutputs);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index b57206c9c4..2cd8d8a03b 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -7645,6 +7645,24 @@ op {
}
}
op {
+ name: "EnqueueInQueueDataset"
+ input_arg {
+ name: "queue"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "components"
+ type_list_attr: "Tcomponents"
+ }
+ attr {
+ name: "Tcomponents"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "Enter"
input_arg {
name: "data"
@@ -12330,11 +12348,11 @@ op {
}
input_arg {
name: "num_lower"
- type: DT_INT64
+ type_attr: "Tindex"
}
input_arg {
name: "num_upper"
- type: DT_INT64
+ type_attr: "Tindex"
}
output_arg {
name: "band"
@@ -12344,6 +12362,19 @@ op {
name: "T"
type: "type"
}
+ attr {
+ name: "Tindex"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
}
op {
name: "MatrixDeterminant"
@@ -15927,6 +15958,48 @@ op {
}
}
op {
+ name: "PrependFromQueueAndPaddedBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "padded_shapes"
+ type: DT_INT64
+ number_attr: "N"
+ }
+ input_arg {
+ name: "padding_values"
+ type_list_attr: "Toutput_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "Toutput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "PreventGradient"
input_arg {
name: "input"
diff --git a/tensorflow/core/platform/cloud/gcs_throttle.h b/tensorflow/core/platform/cloud/gcs_throttle.h
index 8e46fca6ca..1a89daef08 100644
--- a/tensorflow/core/platform/cloud/gcs_throttle.h
+++ b/tensorflow/core/platform/cloud/gcs_throttle.h
@@ -126,7 +126,7 @@ class GcsThrottle {
void UpdateState() EXCLUSIVE_LOCKS_REQUIRED(mu_);
inline uint64 request_bytes_to_tokens(size_t num_bytes) {
- return num_bytes >> 8;
+ return num_bytes >> 10;
}
mutex mu_;
diff --git a/tensorflow/core/platform/cloud/gcs_throttle_test.cc b/tensorflow/core/platform/cloud/gcs_throttle_test.cc
index a1e8167c27..694756022e 100644
--- a/tensorflow/core/platform/cloud/gcs_throttle_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_throttle_test.cc
@@ -68,7 +68,7 @@ TEST_F(GcsThrottleTest, RejectRequest) {
TEST_F(GcsThrottleTest, MarkResponses) {
time_.AdvanceSeconds(1);
EXPECT_TRUE(throttle_.AdmitRequest());
- throttle_.RecordResponse(32000000); // 32 MB response
+ throttle_.RecordResponse(128000000); // 128 MB response
EXPECT_EQ(-25100, throttle_.available_tokens());
EXPECT_FALSE(throttle_.AdmitRequest());
time_.AdvanceSeconds(1);
diff --git a/tensorflow/docs_src/api_guides/python/TPUEstimator.md b/tensorflow/docs_src/api_guides/python/TPUEstimator.md
new file mode 100644
index 0000000000..d74d7f3181
--- /dev/null
+++ b/tensorflow/docs_src/api_guides/python/TPUEstimator.md
@@ -0,0 +1,396 @@
+# Using TPUs
+
+This document walks through the principal TensorFlow APIs necessary to make
+effective use of a [Cloud TPU](https://cloud.google.com/tpu/), and highlights
+the differences between regular TensorFlow usage and usage on a TPU.
+
+This doc is aimed at users who:
+
+* Are familiar with TensorFlow's `Estimator` and `Dataset` APIs
+* Have maybe [tried out a Cloud TPU](https://cloud.google.com/tpu/docs/quickstart)
+ using an existing model.
+* Have, perhaps, skimmed the code of an example TPU model
+ [[1]](https://github.com/tensorflow/models/blob/master/official/mnist/mnist_tpu.py)
+ [[2]](https://github.com/tensorflow/tpu-demos/tree/master/cloud_tpu/models).
+* Are interested in porting an existing `Estimator` model to
+ run on Cloud TPUs
+
+## TPUEstimator
+
+@{tf.estimator.Estimator$Estimators} are TensorFlow's model-level abstraction.
+Standard `Estimators` can drive models on CPUs and GPUs. You must use
+@{tf.contrib.tpu.TPUEstimator} to drive a model on TPUs.
+
+Refer to TensorFlow's Getting Started section for an introduction to the basics
+of using a @{$get_started/premade_estimators$pre-made `Estimator`}, and
+@{$get_started/custom_estimators$custom `Estimator`s}.
+
+The `TPUEstimator` class differs somewhat from the `Estimator` class.
+
+The simplest way to maintain a model that can be run both on CPU/GPU and on a
+Cloud TPU is to define the model's inference phase (from inputs to predictions)
+outside of the `model_fn`. Then maintain separate implementations of the
+`Estimator` setup and `model_fn`, both wrapping this inference step. For an
+example of this pattern compare the `mnist.py` and `mnist_tpu.py` implementations in
+[tensorflow/models](https://github.com/tensorflow/models/tree/master/official/mnist).
+
+### Running a `TPUEstimator` locally
+
+To create a standard `Estimator` you call the constructor, and pass it a
+`model_fn`, for example:
+
+```
+my_estimator = tf.estimator.Estimator(
+ model_fn=my_model_fn)
+```
+
+The changes required to use a @{tf.contrib.tpu.TPUEstimator} on your local
+machine are relatively minor. The constructor requires two additional arguments.
+You should set the `use_tpu` argument to `False`, and pass a
+@{tf.contrib.tpu.RunConfig} as the `config` argument, as shown below:
+
+``` python
+my_tpu_estimator = tf.contrib.tpu.TPUEstimator(
+ model_fn=my_model_fn,
+ config=tf.contrib.tpu.RunConfig(),
+ use_tpu=False)
+```
+
+Just this simple change will allow you to run a `TPUEstimator` locally.
+The majority of example TPU models can be run in this local mode,
+by setting the command line flags as follows:
+
+
+```
+$> python mnist_tpu.py --use_tpu=false --master=''
+```
+
+Note: This `use_tpu=False` argument is useful for trying out the `TPUEstimator`
+API. It is not meant to be a complete TPU compatibility test. Successfully
+running a model locally in a `TPUEstimator` does not guarantee that it will
+work on a TPU.
+
+
+### Building a `tpu.RunConfig`
+
+While the default `RunConfig` is sufficient for local training, these settings
+cannot be ignored in real usage.
+
+A more typical setup for a `RunConfig`, that can be switched to use a Cloud
+TPU, might be as follows:
+
+``` python
+import tempfile
+import subprocess
+
+class FLAGS(object):
+ use_tpu=False
+ tpu_name=None
+ # Use a local temporary path for the `model_dir`
+ model_dir = tempfile.mkdtemp()
+ # Number of training steps to run on the Cloud TPU before returning control.
+ iterations = 50
+ # A single Cloud TPU has 8 shards.
+ num_shards = 8
+
+if FLAGS.use_tpu:
+ my_project_name = subprocess.check_output([
+ 'gcloud','config','get-value','project'])
+ my_zone = subprocess.check_output([
+ 'gcloud','config','get-value','compute/zone'])
+ cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
+ tpu_names=[FLAGS.tpu_name],
+ zone=my_zone,
+ project=my_project_name)
+ master = cluster_resolver.get_master()
+else:
+ master = ''
+
+my_tpu_run_config = tf.contrib.tpu.RunConfig(
+ master=master,
+ evaluation_master=master,
+ model_dir=FLAGS.model_dir,
+ session_config=tf.ConfigProto(
+ allow_soft_placement=True, log_device_placement=True),
+ tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations,
+ FLAGS.num_shards),
+)
+```
+
+Then you must pass the @{tf.contrib.tpu.RunConfig} to the constructor:
+
+``` python
+my_tpu_estimator = tf.contrib.tpu.TPUEstimator(
+ model_fn=my_model_fn,
+ config = my_tpu_run_config,
+ use_tpu=FLAGS.use_tpu)
+```
+
+Typically the `FLAGS` would be set by command line arguments. To switch from
+training locally to training on a cloud TPU you would need to:
+
+ 1) Set `FLAGS.use_tpu` to `True`
+ 1) Set `FLAGS.tpu_name` so the
+ `tf.contrib.cluster_resolver.TPUClusterResolver` can find it
+ 1) Set `FLAGS.model_dir` to a Google Cloud Storage bucket url (`gs://`).
+
+
+## Optimizer
+
+When training on a cloud TPU you **must** wrap the optimizer in a
+@{tf.contrib.tpu.CrossShardOptimizer}, which uses an `allreduce` to aggregate
+gradients and broadcast the result to each shard (each TPU core).
+
+The `CrossShardOptimizer` is not compatible with local training. So, to have
+the same code run both locally and on a Cloud TPU, add lines like the following:
+
+``` python
+optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
+if FLAGS.use_tpu:
+ optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
+```
+
+If you prefer to avoid a global `FLAGS` variable in your model code, one
+approach is to set the optimizer as one of the `Estimator`'s params,
+as follows:
+
+``` python
+my_tpu_estimator = tf.contrib.tpu.TPUEstimator(
+ model_fn=my_model_fn,
+ config = my_tpu_run_config,
+ use_tpu=FLAGS.use_tpu,
+ params={'optimizer':optimizer})
+```
+
+## Model Function
+
+This section details the changes you must make to the model function
+(`model_fn()`) to make it `TPUEstimator` compatible.
+
+### Static shapes
+
+During regular usage TensorFlow attempts to determine the shapes of each
+`tf.Tensor` during graph construction. During execution any unknown shape
+dimensions are determined dynamically;
+see @{$programmers_guide/tensors#shape$Tensor Shapes} for more details.
+
+To run on Cloud TPUs, TensorFlow models are compiled using @{$xla$XLA}.
+XLA uses a similar system for determining shapes at compile time, but it
+requires that all tensor dimensions be statically defined at compile time. All
+shapes must evaluate to a constant and must not depend on external data or on
+stateful operations such as variables or a random number generator.
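+
+As a minimal sketch (the tensor names here are illustrative, not taken from
+the example models above), the following contrasts a shape that XLA can
+compile with one that depends on the data and therefore cannot:
+
+``` python
+x = tf.random_uniform([128, 32])  # static shape, known at graph construction
+mask = x[:, 0] > 0.5
+y = tf.boolean_mask(x, mask)      # data-dependent shape, only known at run time
+```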
+
+
+### Summaries
+
+Remove any use of `tf.summary` from your model.
+
+@{$summaries_and_tensorboard$TensorBoard summaries} are a great way to see
+inside your model. A minimal set of basic summaries is automatically recorded
+by the `TPUEstimator`, to `event` files in the `model_dir`. Custom summaries, however,
+are currently unsupported when training on a Cloud TPU. So while the
+`TPUEstimator` will still run locally with summaries, it will fail if used on a
+TPU.
+
+### Metrics
+
+Build your evaluation metrics dictionary in a stand-alone `metric_fn`.
+
+<!-- TODO(markdaoust) link to programmers_guide/metrics when it exists -->
+
+Evaluation metrics are an essential part of training a model. These are fully
+supported on Cloud TPUs, but with a slightly different syntax.
+
+A standard @{tf.metrics} function returns a pair of tensors: the first holds
+the running average of the metric value, while the second updates the running
+average and returns the value for the current batch:
+
+```
+running_average, current_batch = tf.metrics.accuracy(labels, predictions)
+```
+
+In a standard `Estimator` you create a dictionary of these pairs, and return it
+as part of the `EstimatorSpec`.
+
+```python
+my_metrics = {'accuracy': tf.metrics.accuracy(labels, predictions)}
+
+return tf.estimator.EstimatorSpec(
+ ...
+ eval_metric_ops=my_metrics
+)
+```
+
+In a `TPUEstimator` you instead pass a function (which returns a metrics
+dictionary) and a list of argument tensors, as shown below:
+
+```python
+def my_metric_fn(labels, predictions):
+ return {'accuracy': tf.metrics.accuracy(labels, predictions)}
+
+return tf.contrib.tpu.TPUEstimatorSpec(
+ ...
+ eval_metrics=(my_metric_fn, [labels, predictions])
+)
+```
+
+### Use `TPUEstimatorSpec`
+
+`TPUEstimatorSpec` does not support hooks, and requires function wrappers for
+some fields.
+
+An `Estimator`'s `model_fn` must return an `EstimatorSpec`. An `EstimatorSpec`
+is a simple structure of named fields containing all the `tf.Tensors` of the
+model that the `Estimator` may need to interact with.
+
+`TPUEstimators` use a @{tf.contrib.tpu.TPUEstimatorSpec}. There are a few
+differences between it and a standard @{tf.estimator.EstimatorSpec}:
+
+
+* The `eval_metric_ops` must be wrapped into a `metric_fn`; this field is
+ renamed `eval_metrics` ([see above](#metrics)).
+* The @{tf.train.SessionRunHook$hooks} are unsupported, so these fields are
+ omitted.
+* The @{tf.train.Scaffold$`scaffold`}, if used, must also be wrapped in a
+ function. This field is renamed to `scaffold_fn`.
+
+`Scaffold` and `Hooks` are for advanced usage, and can typically be omitted.
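+
+Putting the pieces above together, a minimal `model_fn` sketch might look as
+follows. Here `my_network` and the `FLAGS` object are assumptions standing in
+for your own inference code and flag handling; they are not part of the
+`TPUEstimator` API.
+
+``` python
+def my_model_fn(features, labels, mode, params):
+  logits = my_network(features)  # assumed: your model's inference step
+  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
+
+  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+  if FLAGS.use_tpu:
+    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
+  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+
+  def my_metric_fn(labels, logits):
+    return {'accuracy': tf.metrics.accuracy(labels, tf.argmax(logits, 1))}
+
+  return tf.contrib.tpu.TPUEstimatorSpec(
+      mode=mode,
+      loss=loss,
+      train_op=train_op,
+      eval_metrics=(my_metric_fn, [labels, logits]))
+```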
+
+## Input functions
+
+Input functions work mostly unchanged, since they run on the host computer
+rather than on the Cloud TPU itself. This section explains the two necessary
+adjustments.
+
+### Params argument
+
+<!-- TODO(markdaoust) link to input_fn doc when it exists -->
+
+The `input_fn` for a standard `Estimator` _can_ include a
+`params` argument; the `input_fn` for a `TPUEstimator` *must* include a
+`params` argument. This is necessary to allow the estimator to set the batch
+size for each replica of the input stream. So the minimum signature for an
+`input_fn` for a `TPUEstimator` is:
+
+```
+def my_input_fn(params):
+ pass
+```
+
+Here `params['batch_size']` will contain the batch size.
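+
+As a hedged sketch, an `input_fn` that respects the per-replica batch size
+might look like the following. The `my_features` and `my_labels` arrays are
+assumed in-memory data, and the batching transformation used here is discussed
+in the next section:
+
+``` python
+def my_input_fn(params):
+  batch_size = params['batch_size']  # filled in by the TPUEstimator per replica
+  ds = tf.data.Dataset.from_tensor_slices((my_features, my_labels))
+  ds = ds.repeat().apply(
+      tf.contrib.data.batch_and_drop_remainder(batch_size))
+  return ds.make_one_shot_iterator().get_next()
+```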
+
+### Static shapes and batch size
+
+The input pipeline generated by your `input_fn` is run on the CPU, so it is
+mostly free of the strict static shape requirements imposed by the XLA/TPU
+environment. The one requirement is that the batches of data fed from your
+input pipeline to the TPU have a static shape, as determined by the standard
+TensorFlow shape inference algorithm. Intermediate tensors are free to have
+dynamic shapes. If shape inference has failed but the shape is known, it is
+possible to impose the correct shape using the tensor's `set_shape()` method.
+
+In the example below the shape
+inference algorithm fails, but it is corrected using `set_shape`:
+
+```
+>>> x = tf.zeros(tf.constant([1,2,3])+1)
+>>> x.shape
+
+TensorShape([Dimension(None), Dimension(None), Dimension(None)])
+
+>>> x.set_shape([2,3,4])
+```
+
+In many cases the batch size is the only unknown dimension.
+
+A typical input pipeline, using `tf.data`, will usually produce batches of a
+fixed size. The last batch of a finite `Dataset`, however, is typically smaller,
+containing just the remaining elements. Since a `Dataset` does not know its own
+length or finiteness, the standard @{tf.data.Dataset.batch$`batch`} method
+cannot determine on its own whether all batches will have a fixed size:
+
+```
+>>> params = {'batch_size':32}
+>>> ds = tf.data.Dataset.from_tensors([0, 1, 2])
+>>> ds = ds.repeat().batch(params['batch_size'])
+>>> ds
+
+<BatchDataset shapes: (?, 3), types: tf.int32>
+```
+
+The most straightforward fix is to
+@{tf.data.Dataset.apply$apply} @{tf.contrib.data.batch_and_drop_remainder}
+as follows:
+
+```
+>>> params = {'batch_size':32}
+>>> ds = tf.data.Dataset.from_tensors([0, 1, 2])
+>>> ds = ds.repeat().apply(
+... tf.contrib.data.batch_and_drop_remainder(params['batch_size']))
+>>> ds
+
+ <_RestructuredDataset shapes: (32, 3), types: tf.int32>
+```
+
+The one downside to this approach is that, as the name implies, this batching
+method throws out any fractional batch at the end of the dataset. This is fine
+for an infinitely repeating dataset being used for training, but could be a
+problem if you want to train for an exact number of epochs.
+
+For an exact one-epoch _evaluation_ pass you can work around this by manually
+padding the length of the batches and setting the padding entries to have zero
+weight when creating your `tf.metrics`, as sketched below.
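+
+One way to sketch this, assuming the dataset `ds` yields `(features, labels)`
+pairs and that its total length `num_examples` is known ahead of time (both
+assumptions, not part of any API), is to append zero-weighted padding elements
+before batching and forward the weight to your metrics:
+
+``` python
+num_pad = -num_examples % params['batch_size']  # elements to fill the last batch
+real = ds.map(lambda features, labels: (features, labels, 1.0))
+padding = ds.take(1).map(
+    lambda features, labels: (features, labels, 0.0)).repeat(num_pad)
+padded = real.concatenate(padding).apply(
+    tf.contrib.data.batch_and_drop_remainder(params['batch_size']))
+# In the metric_fn, pass the third component as `weights=` so that the padding
+# elements do not affect the reported metric values.
+```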
+
+## Datasets
+
+Efficient use of the `tf.data.Dataset` API is critical when using a Cloud
+TPU, as it is impossible to make full use of a Cloud TPU unless you can feed
+it data quickly enough. See @{$datasets_performance} for details on dataset
+performance.
+
+For all but the simplest experimentation (using
+@{tf.data.Dataset.from_tensor_slices} or other in-graph data) you will need to
+store all data files read by the `TPUEstimator`'s `Dataset` in Google Cloud
+Storage Buckets.
+
+<!--TODO(markdaoust): link to the `TFRecord` doc when it exists.-->
+
+For most use-cases, we recommend converting your data into `TFRecord`
+format and using a @{tf.data.TFRecordDataset} to read it. This, however, is not
+a hard requirement and you can use other dataset readers
+(`FixedLengthRecordDataset` or `TextLineDataset`) if you prefer.
+
+Small datasets can be loaded entirely into memory using
+@{tf.data.Dataset.cache}.
+
+Regardless of the data format used, it is strongly recommended that you
+@{$performance_guide#use_large_files$use large files}, on the order of
+100MB. This is especially important in this networked setting as the overhead
+of opening a file is significantly higher.
+
+It is also important, regardless of the type of reader used, to enable buffering
+using the `buffer_size` argument to the constructor. This argument is specified
+in bytes. A minimum of a few MB (`buffer_size=8*1024*1024`) is recommended so
+that data is available when needed.
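+
+For example, a sketch of reading `TFRecord` files from a Cloud Storage bucket
+with a large read buffer (the bucket path and file pattern are placeholders):
+
+``` python
+filenames = tf.data.Dataset.list_files("gs://my-bucket/train-*.tfrecord")
+ds = filenames.interleave(
+    lambda f: tf.data.TFRecordDataset(f, buffer_size=8 * 1024 * 1024),
+    cycle_length=4)
+```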
+
+The TPU-demos repo includes
+[a script](https://github.com/tensorflow/tpu-demos/blob/master/cloud_tpu/datasets/imagenet_to_gcs.py)
+for downloading the imagenet dataset and converting it to an appropriate format.
+This together with the imagenet
+[models](https://github.com/tensorflow/tpu-demos/tree/master/cloud_tpu/models)
+included in the repo demonstrate all of these best-practices.
+
+
+## What Next
+
+For details on how to actually set up and run a Cloud TPU see:
+
+ * [Google Cloud TPU Documentation](https://cloud.google.com/tpu/docs/)
+
+This document is by no means exhaustive. The best source of further detail on
+how to make a model Cloud TPU compatible is the set of example models
+published in:
+
+ * The [TPU Demos Repository.](https://github.com/tensorflow/tpu-demos/)
+
+For more information about tuning TensorFlow code for performance see:
+
+ * The @{$performance$Performance Section.}
+
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index c1e4545d20..bc7d2080dc 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -272,8 +272,6 @@ Found possible Python library paths:
Please input the desired Python library path to use. Default is [/usr/lib/python2.7/dist-packages]
Using python library path: /usr/local/lib/python2.7/dist-packages
-Do you wish to build TensorFlow with MKL support? [y/N]
-No MKL support will be enabled for TensorFlow
Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:
Do you wish to use jemalloc as the malloc implementation? [Y/n]
jemalloc enabled
diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md
index 10e7ad7ada..cd47fc2803 100644
--- a/tensorflow/docs_src/performance/performance_guide.md
+++ b/tensorflow/docs_src/performance/performance_guide.md
@@ -498,7 +498,7 @@ For TensorFlow source versions after 1.3.0:
```bash
./configure
# Pick the desired options
-bazel build --config=mkl -c opt //tensorflow/tools/pip_package:build_pip_package
+bazel build --config=mkl --config=opt //tensorflow/tools/pip_package:build_pip_package
```
diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md
index 9f50be5b31..f27a658342 100644
--- a/tensorflow/docs_src/programmers_guide/saved_model.md
+++ b/tensorflow/docs_src/programmers_guide/saved_model.md
@@ -285,7 +285,7 @@ with tf.Session(graph=tf.Graph()) as sess:
```
-### Loading a Savedmodel in C++
+### Loading a SavedModel in C++
The C++ version of the SavedModel
[loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/loader.h)
@@ -303,6 +303,30 @@ LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},
&bundle);
```
+### Loading and Serving a SavedModel in TensorFlow Serving
+
+You can easily load and serve a SavedModel with the TensorFlow Serving Model
+Server binary. See [instructions](https://www.tensorflow.org/serving/setup#installing_using_apt-get)
+on how to install the server, or build it if you wish.
+
+Once you have the Model Server, run it with:
+```
+tensorflow_model_server --port=port-numbers --model_name=your-model-name --model_base_path=your_model_base_path
+```
+Set the `--port` and `--model_name` flags to values of your choosing. The
+`--model_base_path` flag expects a path to a base directory, with each version
+of your model residing in a numerically named subdirectory of that base
+directory. For example, suppose the base directory is `/tmp/model`. If you
+have only one version of your model, store it in `/tmp/model/0001`. If you
+have two versions of your model, store the second version in
+`/tmp/model/0002`, and so on. Set the `--model_base_path` flag to the base
+directory (`/tmp/model`, in this example). TensorFlow Model Server will serve
+the model in the highest numbered subdirectory of that base directory.
### Standard constants
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 5b19c90238..cb47651d7b 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -8729,31 +8729,6 @@ func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Out
return op.Output(0)
}
-// Compute the pairwise cross product.
-//
-// `a` and `b` must be the same shape; they can either be simple 3-element vectors,
-// or any shape where the innermost dimension is 3. In the latter case, each pair
-// of corresponding 3-element vectors is cross-multiplied independently.
-//
-// Arguments:
-// a: A tensor containing 3-element vectors.
-// b: Another tensor, of same type and shape as `a`.
-//
-// Returns Pairwise cross product of the vectors in `a` and `b`.
-func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Cross",
- Input: []tf.Input{
- a, b,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Transforms a vector of brain.Example protos (as strings) into typed tensors.
//
// Arguments:
@@ -21290,6 +21265,31 @@ func StatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output
return op.Output(0)
}
+// Compute the pairwise cross product.
+//
+// `a` and `b` must be the same shape; they can either be simple 3-element vectors,
+// or any shape where the innermost dimension is 3. In the latter case, each pair
+// of corresponding 3-element vectors is cross-multiplied independently.
+//
+// Arguments:
+// a: A tensor containing 3-element vectors.
+// b: Another tensor, of same type and shape as `a`.
+//
+// Returns Pairwise cross product of the vectors in `a` and `b`.
+func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Cross",
+ Input: []tf.Input{
+ a, b,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Performs a padding as a preprocess during a convolution.
//
// Similar to FusedResizeAndPadConv2d, this op allows for an optimized
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 43cbde69d9..8b8adefa65 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -357,6 +357,9 @@ tf_py_test(
"//tensorflow/python:session",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:lookup_ops",
],
grpc_enabled = True,
tags = [
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
index 45dfa13720..2c65c49ebd 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
@@ -21,6 +21,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@@ -28,6 +29,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
@@ -103,6 +106,40 @@ class IteratorClusterTest(test.TestCase):
"/job:worker/replica:0/task:1/cpu:0",
workers[0].target)
+ def testCaptureHashTableInSharedIterator(self):
+ worker, _ = test_util.create_local_cluster(1, 1)
+
+ # NOTE(mrry): We must use the V2 variants of `HashTable`
+ # etc. because these produce a `tf.resource`-typed output that is
+ # compatible with the in-graph function implementation.
+ default_val = -1
+ keys = constant_op.constant(["brain", "salad", "surgery"])
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ table = lookup_ops.HashTable(
+ lookup_ops.KeyValueTensorInitializer(keys, values),
+ default_val,
+ shared_name="shared_table")
+
+ input_sentences = dataset_ops.Dataset.from_tensor_slices(
+ ["brain brain tank salad surgery", "surgery brain"])
+
+ iterator = (
+ input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
+ table.lookup)
+ .make_initializable_iterator(shared_name="shared_iterator"))
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with session.Session(worker[0].target) as sess:
+ sess.run(table.init)
+ sess.run(init_op)
+ self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))
+
+ with session.Session(worker[0].target) as sess:
+ self.assertAllEqual([2, 0], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 7e0feb0669..c4b7e4919b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1459,6 +1459,19 @@ def _padding_value_to_tensor(value, output_type):
return value
+def _default_padding(input_dataset):
+
+ def make_zero(t):
+ if t.base_dtype == dtypes.string:
+ return ""
+ elif t.base_dtype == dtypes.variant:
+ raise TypeError("Unable to create padding for field of type 'variant'")
+ else:
+ return np.zeros_like(t.as_numpy_dtype())
+
+ return nest.map_structure(make_zero, input_dataset.output_types)
+
+
class PaddedBatchDataset(Dataset):
"""A `Dataset` that batches and pads contiguous elements from its input."""
@@ -1474,23 +1487,13 @@ class PaddedBatchDataset(Dataset):
batch_size, dtype=dtypes.int64, name="batch_size")
padding_values = (
padding_values
- if padding_values is not None else self._default_padding(input_dataset))
+ if padding_values is not None else _default_padding(input_dataset))
self._padded_shapes = nest.map_structure_up_to(
input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes)
self._padding_values = nest.map_structure_up_to(
input_dataset.output_shapes, _padding_value_to_tensor, padding_values,
input_dataset.output_types)
- def _default_padding(self, input_dataset):
-
- def make_zero(t):
- if t.base_dtype == dtypes.string:
- return ""
- else:
- return np.zeros_like(t.as_numpy_dtype())
-
- return nest.map_structure(make_zero, input_dataset.output_types)
-
def _as_variant_tensor(self):
return gen_dataset_ops.padded_batch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index 6d2f730016..e90ce3fb40 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -383,8 +383,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
"structure has keys %s, while shallow structure has keys %s." %
(list(_six.iterkeys(input_tree)),
list(_six.iterkeys(shallow_tree))))
- input_tree = list(_six.iteritems(input_tree))
- shallow_tree = list(_six.iteritems(shallow_tree))
+ input_tree = list(sorted(_six.iteritems(input_tree)))
+ shallow_tree = list(sorted(_six.iteritems(shallow_tree)))
for shallow_branch, input_branch in zip(shallow_tree, input_tree):
assert_shallow_structure(shallow_branch, input_branch,
diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py
index 90dd7dfe77..ff380815a4 100644
--- a/tensorflow/python/data/util/nest_test.py
+++ b/tensorflow/python/data/util/nest_test.py
@@ -277,6 +277,10 @@ class NestTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)
+ inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
+ inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
+ nest.assert_shallow_structure(inp_ab, inp_ba)
+
def testFlattenUpTo(self):
input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
shallow_tree = ((True, True), (False, True))
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 81b1f6f12a..f5d0759bdc 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -292,6 +292,22 @@ def _map_sequence_obj_to_idx(sequence):
return {id(x): i for i, x in enumerate(sequence)}
+def _flatten(sequence):
+ """A wrapper around `nest.flatten` that also unpacks `IndexedSlices`."""
+ # TODO(akshayka): Support `SparseTensor` in a similar fashion.
+ flat_sequence = nest.flatten(sequence)
+ outputs = []
+ for item in flat_sequence:
+ if isinstance(item, ops.IndexedSlices):
+ if item.dense_shape is not None:
+ outputs.extend([item.values, item.indices, item.dense_shape])
+ else:
+ outputs.extend([item.values, item.indices])
+ else:
+ outputs.append(item)
+ return outputs
+
+
class GraphModeFunction(object):
"""Callable object representing a graph-mode function.
@@ -333,14 +349,14 @@ class GraphModeFunction(object):
self._input_placeholders = input_placeholders
self._extra_inputs = list(extra_inputs)
self._graph = graph
- self._has_backprop = False
+ self._backward_function = None
self._func_name = name
self._function_def = defined_function
self._num_outputs = len(defined_function.signature.output_arg)
self._ops = operations
self._func_outputs = func_outputs
self._returns = [func_outputs] if isinstance(
- func_outputs, (ops.Tensor, type(None))) else list(func_outputs)
+ func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs)
self._output_shapes = output_shapes
self._variables = variables if variables is not None else []
@@ -348,9 +364,8 @@ class GraphModeFunction(object):
def variables(self):
return self._variables
- def _compute_backprop(self):
- """Computes the backprop function object for this function."""
- self._has_backprop = True
+ def _construct_backprop_function(self):
+ """Constructs the backprop function object for this function."""
with self._graph.as_default(), context.graph_mode():
c = _CapturingContext()
with c:
@@ -361,13 +376,16 @@ class GraphModeFunction(object):
filtered_outputs,
self._input_placeholders,
grad_ys=self._out_grad_placeholders)
- shapes = tuple(x.shape for x in in_gradients if x is not None)
+
+ backward_outputs = tuple(
+ grad for grad in _flatten(in_gradients) if grad is not None)
+ output_shapes = tuple(grad.shape for grad in backward_outputs)
+
captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
forward_name = _forward_name(self._func_name)
self._forward_fdef = _EagerDefinedFunction(
forward_name, self._graph, self._ops, self._input_placeholders,
filtered_outputs + captures)
- backward_outputs = tuple(x for x in in_gradients if x is not None)
all_inputs = self._out_grad_placeholders + captures
# Excluding input ops from the body as we do not intend to execute these
# operations when the function is executed.
@@ -381,7 +399,7 @@ class GraphModeFunction(object):
bname = _backward_name(self._func_name)
self._backward_function = GraphModeFunction(
bname, all_inputs, [], self._graph, function_def_ops,
- backward_outputs, in_gradients, shapes)
+ backward_outputs, in_gradients, output_shapes)
def _backprop_call(self, args):
"""Calls the wrapped function and records the result on a tape."""
@@ -426,9 +444,24 @@ class GraphModeFunction(object):
@property
def output_shapes(self):
+ """The function's output shapes."""
# TODO(ebrevdo): Should we only keep the output shapes associated
# with len(self._returns) outputs?
- return nest.pack_sequence_as(self._func_outputs, self._output_shapes)
+ outputs_list = nest.flatten(self._func_outputs)
+ j = 0
+ for i, o in enumerate(outputs_list):
+ if o is not None:
+ if isinstance(o, ops.IndexedSlices):
+ # Extract the shape of the `IndexedSlices` object's `values` field.
+ outputs_list[i] = self._output_shapes[j] # the `values` shape
+ if o.dense_shape is not None:
+ j += 3 # skip over shapes for `values`, `indices`, `dense_shape`
+ else:
+ j += 2 # skip over shapes for `values`, `indices`
+ else:
+ outputs_list[i] = self._output_shapes[j]
+ j += 1
+ return nest.pack_sequence_as(self._func_outputs, outputs_list)
@property
def output_dtypes(self):
@@ -457,12 +490,11 @@ class GraphModeFunction(object):
if v._trainable: # pylint: disable=protected-access
tape.watch_variable(v)
- tensor_inputs = [x for x in nest.flatten(args)
- if isinstance(x, ops.Tensor)]
+ tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
if tape.should_record(tensor_inputs) or tape.should_record(
self._extra_inputs):
- if not self._has_backprop:
- self._compute_backprop()
+ if self._backward_function is None:
+ self._construct_backprop_function()
return self._backprop_call(tensor_inputs)
ctx = context.context()
@@ -503,13 +535,30 @@ class GraphModeFunction(object):
"""
if self._func_outputs is None:
return None
+ # Use `nest.flatten` instead of `_flatten` in order to preserve any
+ # IndexedSlices in `self._func_outputs`.
outputs_list = nest.flatten(self._func_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
- outputs_list[i] = result[j]
- j += 1
- return nest.pack_sequence_as(self._func_outputs, outputs_list)
+ if isinstance(o, ops.IndexedSlices):
+ # Repack Tensors for IndexedSlices.
+ if o.dense_shape is not None:
+ outputs_list[i] = ops.IndexedSlices(
+ values=result[j],
+ indices=result[j + 1],
+ dense_shape=result[j + 2])
+ j += 3
+ else:
+ outputs_list[i] = ops.IndexedSlices(
+ values=result[j],
+ indices=result[j + 1])
+ j += 2
+ else:
+ outputs_list[i] = result[j]
+ j += 1
+ ret = nest.pack_sequence_as(self._func_outputs, outputs_list)
+ return ret
def _get_defun_inputs(args):
@@ -555,7 +604,7 @@ def _defun_internal(name, func, args, kwds):
# Returning a closed-over tensor as an output does not trigger a
# call to convert_to_tensor, so we manually capture all such tensors.
- outputs_list = nest.flatten(func_outputs)
+ outputs_list = _flatten(func_outputs)
func_def_outputs = [
_convert_to_graph_tensor(x) for x in outputs_list if x is not None
]
@@ -600,6 +649,18 @@ def _cache_key(x):
"""Cache key for tfe functions."""
if isinstance(x, ops.Tensor):
return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
+ if isinstance(x, ops.IndexedSlices):
+ if x.dense_shape is not None:
+ return tuple([
+ _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access
+ _TensorDtype(x.indices.dtype, x.indices._shape_tuple()), # pylint: disable=protected-access
+ _TensorDtype(x.dense_shape.dtype, x.dense_shape._shape_tuple()) # pylint: disable=protected-access
+ ])
+ else:
+ return tuple([
+ _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access
+ _TensorDtype(x.indices.dtype, x.indices._shape_tuple()) # pylint: disable=protected-access
+ ])
if isinstance(x, np.ndarray):
return ("array", x.shape, tuple(x.reshape(-1)))
if isinstance(x, (list, tuple)):
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 2cb2cfb76c..3e8e67ac7e 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -374,6 +374,78 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(f(constant_op.constant(1.0)), 2.0)
+ def testGradientOfGatherWithDefun(self):
+
+ v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
+
+ def sum_gather():
+ return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
+
+ grad_fn = backprop.implicit_grad(sum_gather)
+ gradient = grad_fn()
+ defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather))
+ defun_gradient = defun_grad_fn()
+ self.assertEqual(len(gradient), len(defun_gradient))
+
+ gradient = gradient[0][0]
+ defun_gradient = defun_gradient[0][0]
+ self.assertAllEqual(gradient.values, defun_gradient.values)
+ self.assertAllEqual(gradient.indices, defun_gradient.indices)
+ self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape)
+
+ def testReturningIndexedSlicesWithDefun(self):
+
+ def validate(indexed_slice):
+ def f():
+ return indexed_slice
+
+ output = function.defun(f)()
+ self.assertTrue(isinstance(output, ops.IndexedSlices))
+ self.assertAllEqual(indexed_slice.values, output.values)
+ self.assertAllEqual(indexed_slice.indices, output.indices)
+ self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)
+
+ self.assertEqual(
+ function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
+
+ arg = ops.IndexedSlices(
+ values=constant_op.constant([1, 2]),
+ indices=constant_op.constant([0, 1]),
+ dense_shape=constant_op.constant([2]))
+ validate(arg)
+
+ arg = ops.IndexedSlices(
+ values=constant_op.constant([1, 2]),
+ indices=constant_op.constant([0, 1]),
+ dense_shape=None)
+ validate(arg)
+
+ def testIndexedSliceAsArgumentWithDefun(self):
+
+ @function.defun
+ def f(indexed_slice):
+ return indexed_slice
+
+ def validate(arg):
+ output = f(arg)
+ self.assertTrue(isinstance(output, ops.IndexedSlices))
+ self.assertAllEqual(arg.values, output.values)
+ self.assertAllEqual(arg.indices, output.indices)
+ self.assertAllEqual(arg.dense_shape, output.dense_shape)
+
+ indexed_slice = ops.IndexedSlices(
+ values=constant_op.constant([1]),
+ indices=constant_op.constant([0]),
+ dense_shape=constant_op.constant([1]))
+ validate(indexed_slice)
+
+ # Test that `f` works even when `dense_shape` is None.
+ indexed_slice = ops.IndexedSlices(
+ values=constant_op.constant([1]),
+ indices=constant_op.constant([0]),
+ dense_shape=None)
+ validate(indexed_slice)
+
def testFunctionOnDevice(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index a9dd8d8e9d..fdac22bb53 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -482,6 +482,7 @@ py_test(
size = "small",
srcs = ["_impl/keras/layers/normalization_test.py"],
srcs_version = "PY2AND3",
+ tags = ["notsan"],
deps = [
":keras",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 68b7c3a98a..7cff3e227c 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -952,6 +952,32 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
v = variables.Variable([1, 2])
sess.run(v[:].assign([1, 2]))
+ def testTypeError(self):
+ init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
+ too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
+ too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
+ v = variables.Variable(init_val)
+ with self.assertRaises(TypeError):
+ v[:].assign(too_small_val)
+ with self.assertRaises(TypeError):
+ v[:].assign(too_large_val)
+
+ def testTypeErrorResource(self):
+ init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
+ too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
+ too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
+ v = resource_variable_ops.ResourceVariable(init_val)
+ with self.test_session() as sess:
+ sess.run(v.initializer)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "l-value dtype int32 does not match r-value dtype int64"):
+ sess.run(v[:].assign(too_large_val))
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "l-value dtype int32 does not match r-value dtype int8"):
+ sess.run(v[:].assign(too_small_val))
+
class ShapeSizeRankTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 59fe3df3e5..f1fbe1a745 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -1345,6 +1345,34 @@ class ComputeWeightedLossTest(test.TestCase):
self.assertAllClose(
np.mean(self._raw_losses), unweighted_loss.eval())
+ def testUnweightedFromPlaceholder(self):
+ for reduction in losses.Reduction.all():
+ with ops.Graph().as_default() as g:
+ self.assertEqual(0, len(util.get_losses()))
+ raw_losses = array_ops.placeholder(dtype=dtypes.float32)
+ feed_dict = {raw_losses: self._raw_losses}
+ unweighted_losses = (
+ losses.compute_weighted_loss(raw_losses, reduction=reduction),
+ losses.compute_weighted_loss(
+ raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction),
+ losses.compute_weighted_loss(
+ raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction),
+ )
+ self.assertEqual(3, len(util.get_losses()))
+ with self.test_session(g):
+ for unweighted_loss in unweighted_losses:
+ if reduction == losses.Reduction.NONE:
+ self.assertAllClose(
+ self._raw_losses, unweighted_loss.eval(feed_dict))
+ elif reduction == losses.Reduction.SUM:
+ self.assertAllClose(
+ np.sum(self._raw_losses), unweighted_loss.eval(feed_dict))
+ else:
+ # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS,
+ # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE.
+ self.assertAllClose(
+ np.mean(self._raw_losses), unweighted_loss.eval(feed_dict))
+
def testScalarWeight(self):
with ops.Graph().as_default():
self.assertEqual(0, len(util.get_losses()))
diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
index 317b8dc05b..68d626de2c 100644
--- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
@@ -21,6 +21,7 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -54,9 +55,13 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_):
band_np = np.tril(band_np, upper)
if batch_shape_ is not ():
band_np = np.tile(band_np, batch_shape_ + (1, 1))
- with self.test_session(use_gpu=False):
- band = array_ops.matrix_band_part(batch_mat, lower, upper)
- self.assertAllEqual(band_np, band.eval())
+ for index_dtype in [dtypes_lib.int32, dtypes_lib.int64]:
+ with self.test_session(use_gpu=False):
+ band = array_ops.matrix_band_part(
+ batch_mat,
+ constant_op.constant(lower, index_dtype),
+ constant_op.constant(upper, index_dtype))
+ self.assertAllEqual(band_np, band.eval())
return Test
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 285e544047..8b3c61b933 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -151,7 +151,7 @@ def _num_present(losses, weights, per_batch=False):
def _num_elements(losses):
"""Computes the number of elements in `losses` tensor."""
with ops.name_scope(None, "num_elements", values=[losses]) as scope:
- return array_ops.size(losses, name=scope, out_type=losses.dtype)
+ return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
@tf_export("losses.compute_weighted_loss")
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 874df3d108..c8525ed420 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -532,8 +532,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
(list(_six.iterkeys(input_tree)),
list(_six.iterkeys(shallow_tree))))
- input_tree = list(_six.iteritems(input_tree))
- shallow_tree = list(_six.iteritems(shallow_tree))
+ input_tree = list(sorted(_six.iteritems(input_tree)))
+ shallow_tree = list(sorted(_six.iteritems(shallow_tree)))
for shallow_branch, input_branch in zip(shallow_tree, input_tree):
assert_shallow_structure(shallow_branch, input_branch,
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 6bec397db5..8aaf799fd0 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -425,6 +425,10 @@ class NestTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)
+ inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
+ inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
+ nest.assert_shallow_structure(inp_ab, inp_ba)
+
def testFlattenUpTo(self):
# Shallow tree ends at scalar.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
diff --git a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat
index 957729bb37..c1bc718507 100644
--- a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat
+++ b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat
@@ -36,7 +36,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake
SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe"
:: Run cmake to create Visual Studio Project files.
-%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE%
+%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX
:: Run msbuild in the resulting VS project files to build a pip package.
%MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj
diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat
index 5a362de399..b87e4a9bec 100644
--- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat
+++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat
@@ -37,7 +37,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake
SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe"
:: Run cmake to create Visual Studio Project files.
-%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE%
+%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX
:: Run msbuild in the resulting VS project files to build a pip package.
%MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index b5df633800..543b5fa6fe 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -162,7 +162,7 @@ def _build_class_page(page_info):
parts.append(h3.format(**method_info.__dict__))
if method_info.signature is not None:
- parts.append(_build_signature(method_info))
+ parts.append(_build_signature(method_info, use_full_name=False))
parts.append(method_info.doc.docstring)
parts.append(_build_function_details(method_info.doc.function_details))
@@ -259,14 +259,14 @@ def _build_module_page(page_info):
return ''.join(parts)
-def _build_signature(obj_info):
+def _build_signature(obj_info, use_full_name=True):
"""Returns a md code block showing the function signature."""
# Special case tf.range, since it has an optional first argument
if obj_info.full_name == 'tf.range':
return (
'``` python\n'
- "range(limit, delta=1, dtype=None, name='range')\n"
- "range(start, limit, delta=1, dtype=None, name='range')\n"
+ "tf.range(limit, delta=1, dtype=None, name='range')\n"
+ "tf.range(start, limit, delta=1, dtype=None, name='range')\n"
'```\n\n')
parts = ['``` python']
@@ -281,7 +281,11 @@ def _build_signature(obj_info):
sig = ',\n'.join(' %s' % sig_item for sig_item in obj_info.signature)
sig = '\n'+sig+'\n'
- parts.append(signature_template.format(name=obj_info.short_name, sig=sig))
+ if use_full_name:
+ obj_name = obj_info.full_name
+ else:
+ obj_name = obj_info.short_name
+ parts.append(signature_template.format(name=obj_name, sig=sig))
parts.append('```\n\n')
return '\n'.join(parts)
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 6c9b5e46ee..d7fab2b93a 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -80,13 +80,13 @@ CONSOLE_SCRIPTS = [
# is now declared by the tensorboard pip package. If we remove the
# TensorBoard command, pip will inappropriately remove it during install,
# even though the command is not removed, just moved to a different wheel.
- 'tensorboard = tensorboard.main:main',
+ 'tensorboard = tensorboard.main:run_main',
]
# pylint: enable=line-too-long
# remove the tensorboard console script if building tf_nightly
if 'tf_nightly' in project_name:
- CONSOLE_SCRIPTS.remove('tensorboard = tensorboard.main:main')
+ CONSOLE_SCRIPTS.remove('tensorboard = tensorboard.main:run_main')
TEST_PACKAGES = [
'scipy >= 0.15.1',