aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Amit Patankar <amitpatankar@google.com>2018-01-31 22:41:17 -0800
committerGravatar GitHub <noreply@github.com>2018-01-31 22:41:17 -0800
commitbc69c4ceed6544c109be5693eb40ddcf3a4eb95d (patch)
tree0d7f65dcf49e432c7d4023ca5e0d0631b8fad2c0
parent9e7ce91845500e5111e0400766983e69701a1733 (diff)
parent3bd65900f67af950797ef89dde0d984e8b2d0d7a (diff)
Merge pull request #16637 from case540/branch_184052073
Branch 184052073
-rw-r--r--tensorflow/BUILD1
-rw-r--r--tensorflow/compiler/tests/BUILD12
-rw-r--r--tensorflow/compiler/tests/extract_image_patches_op_test.py134
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py15
-rw-r--r--tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md4
-rw-r--r--tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc169
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc23
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc17
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h14
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i2
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py8
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py23
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc26
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc52
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc3
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc8
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util_test.cc24
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc7
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/cluster_resolver/BUILD1
-rw-r--r--tensorflow/contrib/cluster_resolver/__init__.py12
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py17
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py28
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py15
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py16
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py29
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py388
-rw-r--r--tensorflow/contrib/distributions/python/ops/kumaraswamy.py258
-rw-r--r--tensorflow/contrib/eager/python/tfe_test.py4
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans.py2
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py63
-rw-r--r--tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm10
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc102
-rw-r--r--tensorflow/contrib/lite/kernels/add_test.cc54
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc99
-rw-r--r--tensorflow/contrib/lite/kernels/mul_test.cc59
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc183
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py72
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc3
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc11
-rw-r--r--tensorflow/contrib/meta_graph_transform/meta_graph_transform.py7
-rw-r--r--tensorflow/contrib/py2tf/BUILD59
-rw-r--r--tensorflow/contrib/py2tf/__init__.py9
-rw-r--r--tensorflow/contrib/py2tf/impl/BUILD65
-rw-r--r--tensorflow/contrib/py2tf/impl/api.py (renamed from tensorflow/contrib/py2tf/api.py)4
-rw-r--r--tensorflow/contrib/py2tf/impl/api_test.py (renamed from tensorflow/contrib/py2tf/api_test.py)4
-rw-r--r--tensorflow/contrib/py2tf/impl/config.py (renamed from tensorflow/contrib/py2tf/config.py)3
-rw-r--r--tensorflow/contrib/py2tf/impl/conversion.py (renamed from tensorflow/contrib/py2tf/conversion.py)4
-rw-r--r--tensorflow/contrib/py2tf/impl/conversion_test.py (renamed from tensorflow/contrib/py2tf/conversion_test.py)2
-rw-r--r--tensorflow/contrib/py2tf/impl/naming.py (renamed from tensorflow/contrib/py2tf/naming.py)0
-rw-r--r--tensorflow/contrib/py2tf/impl/naming_test.py (renamed from tensorflow/contrib/py2tf/naming_test.py)2
-rw-r--r--tensorflow/contrib/quantize/python/graph_matcher.py111
-rw-r--r--tensorflow/contrib/quantize/python/graph_matcher_test.py40
-rw-r--r--tensorflow/contrib/summary/summary_ops.py2
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc30
-rw-r--r--tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc11
-rw-r--r--tensorflow/contrib/tpu/profiler/dump_tpu_profile.h2
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py3
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py310
-rw-r--r--tensorflow/core/BUILD7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_AssignSubVariableOp.pbtxt7
-rw-r--r--tensorflow/core/framework/dataset.h2
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc28
-rw-r--r--tensorflow/core/ops/lookup_ops.cc3
-rw-r--r--tensorflow/core/profiler/README.md2
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing_test.cc1
-rw-r--r--tensorflow/go/graph.go9
-rw-r--r--tensorflow/go/op/scope.go32
-rw-r--r--tensorflow/go/op/scope_test.go43
-rw-r--r--tensorflow/python/client/session.py3
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py2
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py2
-rw-r--r--tensorflow/python/data/ops/readers.py4
-rw-r--r--tensorflow/python/eager/BUILD23
-rw-r--r--tensorflow/python/eager/gen_op.bzl65
-rw-r--r--tensorflow/python/eager/python_eager_op_gen_main.cc72
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc41
-rw-r--r--tensorflow/python/estimator/export/export.py4
-rw-r--r--tensorflow/python/estimator/export/export_output.py5
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io.py2
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io.py2
-rw-r--r--tensorflow/python/feature_column/feature_column.py14
-rwxr-xr-xtensorflow/python/keras/BUILD14
-rw-r--r--tensorflow/python/keras/_impl/keras/backend.py14
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/topology.py4
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training.py246
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_eager.py666
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_eager_test.py755
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_test.py18
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/core.py4
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/normalization.py3
-rw-r--r--tensorflow/python/keras/_impl/keras/optimizers.py10
-rw-r--r--tensorflow/python/kernel_tests/extract_image_patches_op_test.py18
-rw-r--r--tensorflow/python/layers/base.py10
-rw-r--r--tensorflow/python/layers/network.py19
-rw-r--r--tensorflow/python/lib/io/file_io.py13
-rw-r--r--tensorflow/python/lib/io/tf_record.py5
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py13
-rw-r--r--tensorflow/python/ops/losses/util.py6
-rw-r--r--tensorflow/python/platform/app.py2
-rw-r--r--tensorflow/python/platform/resource_loader.py6
-rw-r--r--tensorflow/python/platform/tf_logging.py22
-rw-r--r--tensorflow/python/profiler/model_analyzer.py4
-rw-r--r--tensorflow/python/profiler/option_builder.py2
-rw-r--r--tensorflow/python/profiler/tfprof_logger.py2
-rw-r--r--tensorflow/python/summary/writer/writer.py2
-rw-r--r--tensorflow/python/summary/writer/writer_cache.py2
-rw-r--r--tensorflow/python/training/adadelta.py2
-rw-r--r--tensorflow/python/training/adagrad.py2
-rw-r--r--tensorflow/python/training/adagrad_da.py2
-rw-r--r--tensorflow/python/training/adam.py2
-rw-r--r--tensorflow/python/training/basic_loops.py2
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py14
-rw-r--r--tensorflow/python/training/checkpoint_utils.py5
-rw-r--r--tensorflow/python/training/coordinator.py3
-rw-r--r--tensorflow/python/training/device_setter.py2
-rw-r--r--tensorflow/python/training/ftrl.py3
-rw-r--r--tensorflow/python/training/gradient_descent.py2
-rw-r--r--tensorflow/python/training/input.py15
-rw-r--r--tensorflow/python/training/learning_rate_decay.py10
-rw-r--r--tensorflow/python/training/momentum.py2
-rw-r--r--tensorflow/python/training/monitored_session.py8
-rw-r--r--tensorflow/python/training/moving_averages.py2
-rw-r--r--tensorflow/python/training/optimizer.py2
-rw-r--r--tensorflow/python/training/proximal_adagrad.py2
-rw-r--r--tensorflow/python/training/proximal_gradient_descent.py2
-rw-r--r--tensorflow/python/training/queue_runner_impl.py5
-rw-r--r--tensorflow/python/training/rmsprop.py2
-rw-r--r--tensorflow/python/training/saver.py10
-rw-r--r--tensorflow/python/training/server_lib.py3
-rw-r--r--tensorflow/python/training/session_manager.py2
-rw-r--r--tensorflow/python/training/session_run_hook.py5
-rw-r--r--tensorflow/python/training/supervisor.py2
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer.py2
-rw-r--r--tensorflow/python/training/training_util.py6
-rw-r--r--tensorflow/python/util/compat.py9
-rw-r--r--tensorflow/tools/api/generator/BUILD10
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh9
-rwxr-xr-xtensorflow/tools/dist_test/build_server.sh21
-rw-r--r--tensorflow/tools/graph_transforms/sparsify_gather.cc80
-rw-r--r--tensorflow/tools/graph_transforms/sparsify_gather_test.cc86
-rw-r--r--tensorflow/tools/pip_package/BUILD3
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py1
-rw-r--r--tensorflow/tools/pip_package/setup.py1
-rw-r--r--tensorflow/tools/test/file_name_test.py48
-rw-r--r--tensorflow/tools/test/run_and_gather_logs_lib.py2
-rw-r--r--tensorflow/workspace.bzl32
160 files changed, 4559 insertions, 837 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 9e69613c79..2fa02a9b4c 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -550,6 +550,7 @@ filegroup(
"//tensorflow/contrib/predictor:all_files",
"//tensorflow/contrib/py2tf:all_files",
"//tensorflow/contrib/py2tf/converters:all_files",
+ "//tensorflow/contrib/py2tf/impl:all_files",
"//tensorflow/contrib/py2tf/pyct:all_files",
"//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
"//tensorflow/contrib/quantize:all_files",
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 9e64f3e9a3..7277ba42ce 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -256,6 +256,18 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "extract_image_patches_op_test",
+ size = "small",
+ srcs = ["extract_image_patches_op_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py
new file mode 100644
index 0000000000..0361702e7a
--- /dev/null
+++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py
@@ -0,0 +1,134 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for ExtractImagePatches op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ExtractImagePatches(XLATestCase):
+ """Functional tests for ExtractImagePatches op."""
+
+ def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):
+ """Tests input-output pairs for the ExtractImagePatches op.
+
+ Args:
+ image: Input tensor with shape: [batch, in_rows, in_cols, depth].
+ ksizes: Patch size specified as: [ksize_rows, ksize_cols].
+ strides: Output strides, specified as [stride_rows, stride_cols].
+ rates: Atrous rates, specified as [rate_rows, rate_cols].
+ padding: Padding type.
+ patches: Expected output.
+ """
+ ksizes = [1] + ksizes + [1]
+ strides = [1] + strides + [1]
+ rates = [1] + rates + [1]
+
+ with self.test_session():
+ image_placeholder = array_ops.placeholder(dtypes.float32)
+ with self.test_scope():
+ out_tensor = array_ops.extract_image_patches(
+ image_placeholder,
+ ksizes=ksizes,
+ strides=strides,
+ rates=rates,
+ padding=padding,
+ name="im2col")
+ feed_dict = {image_placeholder: image}
+ self.assertAllClose(patches, out_tensor.eval(feed_dict=feed_dict))
+
+ def testKsize1x1Stride1x1Rate1x1(self):
+ """Verifies that for 1x1 kernel the output equals the input."""
+ # [2, 3, 4, 5]
+ image = np.reshape(range(120), [2, 3, 4, 5])
+ # [2, 3, 4, 5]
+ patches = np.reshape(range(120), [2, 3, 4, 5])
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1],
+ strides=[1, 1],
+ rates=[1, 1],
+ padding=padding,
+ patches=patches)
+
+ def testKsize1x1Stride2x3Rate1x1(self):
+ """Test for 1x1 kernel and strides."""
+ # [2, 4, 5, 3]
+ image = np.reshape(range(120), [2, 4, 5, 3])
+ # [2, 2, 2, 3]
+ patches = image[:, ::2, ::3, :]
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1],
+ strides=[2, 3],
+ rates=[1, 1],
+ padding=padding,
+ patches=patches)
+
+ def testKsize2x2Stride1x1Rate1x1Valid(self):
+ """Test for 2x2 kernel with VALID padding."""
+ # [1, 2, 2, 1]
+ image = [[[[1], [2]], [[3], [4]]]]
+ # [1, 1, 1, 4]
+ patches = [[[[1, 2, 3, 4]]]]
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2],
+ strides=[1, 1],
+ rates=[1, 1],
+ padding="VALID",
+ patches=patches)
+
+ def testKsize2x2Stride1x1Rate1x1Same(self):
+ """Test for 2x2 kernel with SAME padding."""
+ # [1, 2, 2, 1]
+ image = [[[[1], [2]], [[3], [4]]]]
+ # [1, 2, 2, 4]
+ patches = [[[[1, 2, 3, 4], [2, 0, 4, 0]], [[3, 4, 0, 0], [4, 0, 0, 0]]]]
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2],
+ strides=[1, 1],
+ rates=[1, 1],
+ padding="SAME",
+ patches=patches)
+
+ def testKsize2x2Stride1x1Rate2x2Valid(self):
+ """Test for 2x2 kernel with 2x2 dilation."""
+ # [1, 2, 2, 1]
+ image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
+ # [1, 2, 2, 4]
+ patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]],
+ [[4, 6, 12, 14], [5, 7, 13, 15]]]]
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2],
+ strides=[1, 1],
+ rates=[2, 2],
+ padding="VALID",
+ patches=patches)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 8e4b8a3833..3d3e112f48 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -154,6 +154,21 @@ class UnaryOpsTest(XLATestCase):
def testFloatOps(self):
for dtype in self.float_types:
+ x = np.arange(-0.90, 0.90, 0.25)
+ self._assertOpOutputMatchesExpected(
+ math_ops.acos,
+ x.astype(dtype),
+ expected=np.arccos(x).astype(dtype))
+ self._assertOpOutputMatchesExpected(
+ math_ops.asin,
+ x.astype(dtype),
+ expected=np.arcsin(x).astype(dtype))
+ x = np.arange(-3, 3).reshape(1, 3, 2)
+ self._assertOpOutputMatchesExpected(
+ math_ops.atan,
+ x.astype(dtype),
+ expected=np.arctan(x).astype(dtype))
+
self._assertOpOutputMatchesExpected(
math_ops.acosh,
np.array([1, 2, 3, 4], dtype=dtype),
diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md
index 44f7db5ffd..91351421bc 100644
--- a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md
+++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md
@@ -71,6 +71,7 @@ Operator | Type Constraint
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
+`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
`FFT` |
`FFT2D` |
`FFT3D` |
@@ -124,6 +125,8 @@ Operator | Type Constraint
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
@@ -176,6 +179,7 @@ Operator | Type Constraint
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
+`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`
diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md
index eb1f891125..b9bdb829d7 100644
--- a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md
+++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md
@@ -71,6 +71,7 @@ Operator | Type Constraint
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
+`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
`FFT` |
`FFT2D` |
`FFT3D` |
@@ -124,6 +125,8 @@ Operator | Type Constraint
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
@@ -173,6 +176,7 @@ Operator | Type Constraint
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
+`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 84fa43f4fb..67be1a4ba6 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -31,6 +31,7 @@ tf_kernel_library(
"diag_op.cc",
"dynamic_stitch_op.cc",
"elu_op.cc",
+ "extract_image_patches_op.cc",
"fft_ops.cc",
"fill_op.cc",
"function_ops.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
new file mode 100644
index 0000000000..b2970eae20
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -0,0 +1,169 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+namespace {
+
+class ExtractImagePatchesOp : public XlaOpKernel {
+ public:
+ explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorFormat data_format = FORMAT_NHWC;
+ const int num_dims = ksizes_.size();
+
+ OP_REQUIRES(
+ ctx, num_dims >= 3,
+ errors::InvalidArgument("Kernel size must have at least 3 dimensions"));
+ const int num_spatial_dims = num_dims - 2;
+
+ OP_REQUIRES(ctx, strides_.size() == num_dims,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify ",
+ num_dims, " dimensions"));
+ OP_REQUIRES(ctx, dilations_.size() == num_dims,
+ errors::InvalidArgument("Dilations field must "
+ "specify ",
+ num_dims, " dimensions"));
+
+ int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
+ OP_REQUIRES(
+ ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
+ errors::Unimplemented("Current implementation does not yet support "
+ "kernel sizes > 1 in the batch and depth "
+ "dimensions."));
+ OP_REQUIRES(
+ ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
+ errors::Unimplemented("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES(
+ ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
+ errors::Unimplemented("Current implementation does not support "
+ "dilations in the batch and depth dimensions."));
+
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
+ OP_REQUIRES(
+ ctx, ksizes_[input_dim] >= 0,
+ errors::Unimplemented("Kernel size values must be non-negative; ", i,
+ "th spatial dimension had dilation ",
+ dilations_[input_dim]));
+ OP_REQUIRES(ctx, strides_[input_dim] >= 1,
+ errors::Unimplemented("Stride values must be positive; ", i,
+ "th spatial dimension had dilation ",
+ dilations_[input_dim]));
+ OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
+ errors::Unimplemented("Dilation values must be positive; ", i,
+ "th spatial dimension had dilation ",
+ dilations_[input_dim]));
+ }
+
+ xla::PrimitiveType type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type));
+
+ const TensorShape input_shape = ctx->InputShape(0);
+ OP_REQUIRES(
+ ctx, input_shape.dims() == num_dims,
+ errors::InvalidArgument("input must be ", num_dims, "-dimensional",
+ input_shape.DebugString()));
+ const int64 depth = input_shape.dim_size(feature_dim);
+
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ // The following code is equivalent to:
+ // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
+ int64 kernel_size = 1;
+ std::vector<int64> lhs_shape(num_dims, 1);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
+ lhs_shape[i] = ksizes_[input_dim];
+ kernel_size *= ksizes_[input_dim];
+ }
+ lhs_shape[num_spatial_dims] = depth;
+ lhs_shape[num_spatial_dims + 1] = 1;
+
+ // Builds an identity matrix as a broadcast equality of iotas.
+ // iota = np.arange(np.prod(ksize), depth)
+ // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
+ xla::ComputationDataHandle iota;
+ TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
+ kernel_size * depth, &iota));
+
+ auto lhs = builder->Reshape(iota, lhs_shape);
+ auto filter = builder->ConvertElementType(
+ builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);
+
+ xla::ConvolutionDimensionNumbers dims;
+ std::vector<int64> window_strides(num_spatial_dims);
+ std::vector<int64> lhs_dilation(num_spatial_dims, 1);
+ std::vector<int64> rhs_dilation(num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(num_spatial_dims);
+
+ dims.set_input_batch_dimension(batch_dim);
+ dims.set_output_batch_dimension(batch_dim);
+ dims.set_input_feature_dimension(feature_dim);
+ dims.set_output_feature_dimension(feature_dim);
+ dims.set_kernel_input_feature_dimension(num_spatial_dims);
+ dims.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
+ dims.add_input_spatial_dimensions(dim);
+ dims.add_kernel_spatial_dimensions(i);
+ dims.add_output_spatial_dimensions(dim);
+ window_strides[i] = strides_.at(dim);
+ rhs_dilation[i] = dilations_.at(dim);
+
+ int64 unused_output_size;
+ OP_REQUIRES_OK(
+ ctx, GetWindowedOutputSizeVerboseV2(
+ input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i],
+ window_strides[i], padding_, &unused_output_size,
+ &padding[i].first, &padding[i].second));
+ }
+
+ xla::ComputationDataHandle conv =
+ builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
+ padding, lhs_dilation, rhs_dilation, dims);
+ ctx->SetOutput(0, conv);
+ }
+
+ protected:
+ std::vector<int32> ksizes_;
+ std::vector<int32> dilations_;
+ std::vector<int32> strides_;
+ Padding padding_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
+};
+
+REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index a266e9013c..0c5ad9e525 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -50,18 +50,41 @@ XLAJIT_MAKE_UNARY(Conj, b->Conj(x));
// Return x if x>0, otherwise -x.
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
+// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
+XLAJIT_MAKE_UNARY(
+ Acos,
+ b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
+ b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
+ b->Mul(x, x)),
+ XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
+ b->Add(XlaHelpers::One(b, input_type(0)), x))));
+
// acosh(x) = log(x + sqrt(x^2 - 1))
XLAJIT_MAKE_UNARY(
Acosh,
b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x),
XlaHelpers::One(b, input_type(0))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+
+// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
+XLAJIT_MAKE_UNARY(
+ Asin,
+ b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
+ b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)),
+ b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
+ b->Mul(x, x)),
+ XlaHelpers::FloatLiteral(b, input_type(0),
+ 0.5))))));
+
// asinh(x) = log(x + sqrt(x^2 + 1))
XLAJIT_MAKE_UNARY(
Asinh,
b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x),
XlaHelpers::One(b, input_type(0))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+
+XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0))));
+
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
XLAJIT_MAKE_UNARY(
Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x),
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index c22fd37129..34e733bc8d 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -88,7 +88,6 @@ cc_library(
visibility = [":friends"],
deps = [
"//tensorflow/core:framework_lite",
- "//tensorflow/core:lib",
"//third_party/eigen3",
],
)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 5772532b84..67a73bc33d 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -367,12 +367,6 @@ LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
source, init_value, scatter.computation());
}
-ComputationDataHandle LocalComputationBuilder::Select(
- const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
- const ComputationDataHandle& on_false) {
- return builder_.Select(pred, on_true, on_false);
-}
-
ComputationDataHandle LocalComputationBuilder::Tuple(
tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
return builder_.Tuple(elements);
@@ -487,6 +481,15 @@ ComputationDataHandle LocalComputationBuilder::While(
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
(lhs, rhs, broadcast_dimensions))
+#define _FORWARD_TRIOP(method_name) \
+ _FORWARD( \
+ method_name, ComputationDataHandle, \
+ (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
+ const ComputationDataHandle& ehs), \
+ (lhs, rhs, ehs))
+
+_FORWARD_TRIOP(Select)
+_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
@@ -507,6 +510,7 @@ _FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
+_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
@@ -523,6 +527,7 @@ _FORWARD_UNOP(Sort)
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
+#undef _FORWARD_TRIOP
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
delete local_shaped_buffer;
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 6851c2644d..d5c4c58040 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -174,10 +174,6 @@ class LocalComputationBuilder {
const ComputationDataHandle& source,
const ComputationDataHandle& init_value, const LocalComputation& scatter);
- ComputationDataHandle Select(const ComputationDataHandle& pred,
- const ComputationDataHandle& on_true,
- const ComputationDataHandle& on_false);
-
ComputationDataHandle Tuple(
tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
@@ -254,6 +250,14 @@ class LocalComputationBuilder {
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
+#define _FORWARD_TRIOP(method_name) \
+ _FORWARD( \
+ method_name, ComputationDataHandle, \
+ (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
+ const ComputationDataHandle& ehs))
+
+ _FORWARD_TRIOP(Select)
+ _FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
@@ -274,6 +278,7 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
+ _FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
@@ -290,6 +295,7 @@ class LocalComputationBuilder {
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
+#undef _FORWARD_TRIOP
private:
ComputationBuilder builder_;
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 6a52a088dd..89f8385501 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -701,6 +701,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Call;
%unignore xla::swig::LocalComputationBuilder::Transpose;
%unignore xla::swig::LocalComputationBuilder::Rev;
+%unignore xla::swig::LocalComputationBuilder::Clamp;
%unignore xla::swig::LocalComputationBuilder::Map;
%unignore xla::swig::LocalComputationBuilder::Reduce;
%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding;
@@ -730,6 +731,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Exp;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
+%unignore xla::swig::LocalComputationBuilder::Round;
%unignore xla::swig::LocalComputationBuilder::Log;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index a89e2643c8..7ee5febc09 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -89,6 +89,7 @@ _UNARY_OPS = [
'Abs',
'Exp',
'Floor',
+ 'Round',
'Ceil',
'Log',
'Sign',
@@ -619,6 +620,13 @@ class ComputationBuilder(object):
return _wrap_data_handle(
self._client.Rev(_unwrap_data_handle(operand), dimensions))
+ def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin
+ """Clamp op."""
+ return _wrap_data_handle(
+ self._client.Clamp(_unwrap_data_handle(min),
+ _unwrap_data_handle(operand),
+ _unwrap_data_handle(max)))
+
def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter):
"""Select and scatter op, used by the gradient of ReduceWindow.
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index c0413b9bbc..3b5bbfd786 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -496,6 +496,12 @@ class SingleOpTest(LocalComputationTest):
c.Exp(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.exp(arr))
+ def testRound(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Round(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.round(arr))
+
def testLog(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
@@ -699,6 +705,23 @@ class SingleOpTest(LocalComputationTest):
self._ExecuteAndCompareExact(
c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]])
+ def testClampF32(self):
+ c = self._NewComputation()
+ c.Clamp(
+ c.Constant(NumpyArrayF32(-1)),
+ c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
+ c.Constant(NumpyArrayF32(2)))
+ self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
+
+ # TODO(b/72689392): re-enable when bug S32 resolved
+ def DISABLED_testClampS32(self):
+ c = self._NewComputation()
+ c.Clamp(
+ c.Constant(NumpyArrayS32(-1)),
+ c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
+ c.Constant(NumpyArrayS32(2)))
+ self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2])
+
def testSelect(self):
c = self._NewComputation()
c.Select(
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index ba82e822b2..fb857559f9 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1618,9 +1618,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
reduce,
HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
}
+
// A Transpose feeding a reduce can simply permute the reduction dimensions
- // field.
- if (arg->opcode() == HloOpcode::kTranspose) {
+ // field if the output of the reduce is a vector or scalar. Higher ranked
+ // result may require a transpose of the output.
+ if (ShapeUtil::Rank(reduce->shape()) <= 1 &&
+ arg->opcode() == HloOpcode::kTranspose) {
auto transpose_dimensions = arg->dimensions();
std::vector<int64> new_reduce_dimensions;
for (auto dim : dimensions) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 71e8133189..0b2d3d4746 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -479,7 +479,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
Status IrEmitter::HandleSort(HloInstruction* sort) {
// TODO(b/26783907): Implement sort on CPU.
- return Unimplemented("Sort is not supported on CPU (b/26783907).");
+ return Unimplemented("Sort is not implemented on CPU.");
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
@@ -522,7 +522,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// TODO(b/31410564): Implement dilation for reduce-window.
if (window_util::HasDilation(window)) {
return Unimplemented(
- "Dilation for reduce-window not implemented on CPU. See b/31410564.");
+ "Dilation for ReduceWindow is not implemented on CPU.");
}
// The called computation should have been emitted previously.
@@ -625,8 +625,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// TODO(b/31410564): Implement dilation for select-and-scatter.
if (window_util::HasDilation(window)) {
return Unimplemented(
- "Dilation for select-and-scatter not implemented on CPU. "
- "See b/31410564.");
+ "Dilation for SelectAndScatter is not implemented on CPU. ");
}
// The select and scatter computations should have been emitted previously.
@@ -1196,8 +1195,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
}
// TODO(b/33011107): Support cross replica sum on CPU.
- return Unimplemented(
- "Cross replica sum is not implemented on CPU. See b/33011107.");
+ return Unimplemented("CrossReplicaSum is not implemented on CPU.");
}
// Fills up the free variables in 'index_with_free_var' with values from
@@ -1811,12 +1809,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
Status IrEmitter::HandleSend(HloInstruction* send) {
// TODO(b/33942983): Support Send/Recv on CPU.
- return Unimplemented("Send is not implemented on CPU. See b/33942983.");
+ return Unimplemented("Send is not implemented on CPU.");
}
Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
- return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
+ return Unimplemented("Send-done is not implemented on CPU.");
}
Status IrEmitter::HandleSlice(HloInstruction* slice) {
@@ -1981,12 +1979,12 @@ Status IrEmitter::HandleDynamicUpdateSlice(
Status IrEmitter::HandleRecv(HloInstruction* recv) {
// TODO(b/33942983): Support Send/Recv on CPU.
- return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
+ return Unimplemented("Recv is not implemented on CPU.");
}
Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
- return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
+ return Unimplemented("Recv-done is not implemented on CPU.");
}
Status IrEmitter::HandlePad(HloInstruction* pad) {
@@ -1995,10 +1993,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
for (auto& padding_dimension : pad->padding_config().dimensions()) {
if (padding_dimension.edge_padding_low() < 0 ||
padding_dimension.edge_padding_high() < 0) {
- return Unimplemented(
- "Negative padding not supported in the CPU backend (b/34628603); "
- "this should have been eliminated at the HLO level: %s",
- pad->padding_config().ShortDebugString().c_str());
+ return InternalErrorStrCat(
+ "Encountered negative padding in IrEmitter on CPU. "
+ "This should have been eliminated at the HLO level. ",
+ pad->ToString());
}
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 9780bac16e..4468adbadb 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -428,7 +428,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
ir_builder_);
case HloOpcode::kSign: {
- // TODO(b/32151903): Ensure consistent sign behavior for -0.0
+ // TODO(b/32151903): Ensure consistent sign behavior for -0.0.
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
@@ -870,7 +870,10 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) const {
if (prim_type != F32) {
- return Unimplemented("inverse erf only implemented for F32 (b/34339814)");
+ // TODO(b/34339814): Implement inverse erf for F64.
+ return Unimplemented(
+ "Inverse erf is only implemented for element "
+ "type F32.");
}
auto getFloat = [&](const float f) {
return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
@@ -1040,17 +1043,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
lhs_value, rhs_value, ir_builder_);
case HloOpcode::kMinimum:
- return ir_builder_->CreateSelect(
- ir_builder_->CreateICmp(
- is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
- lhs_value, rhs_value),
- lhs_value, rhs_value);
+ return EmitIntegralMin(lhs_value, rhs_value, is_signed);
case HloOpcode::kMaximum:
- return ir_builder_->CreateSelect(
- ir_builder_->CreateICmp(
- is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
- lhs_value, rhs_value),
- lhs_value, rhs_value);
+ return EmitIntegralMax(lhs_value, rhs_value, is_signed);
case HloOpcode::kAnd:
return ir_builder_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kOr:
@@ -1067,6 +1062,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
}
}
+llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
+ llvm::Value* rhs_value,
+ bool is_signed) const {
+ return ir_builder_->CreateSelect(
+ ir_builder_->CreateICmp(
+ is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
+}
+
+llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
+ llvm::Value* rhs_value,
+ bool is_signed) const {
+ return ir_builder_->CreateSelect(
+ ir_builder_->CreateICmp(
+ is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
+}
+
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
int64 operand_no) const {
@@ -1363,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
- return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
+ PrimitiveType prim_type = hlo->shape().element_type();
+ if (primitive_util::IsFloatingPointType(prim_type)) {
+ return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
+ } else if (primitive_util::IsIntegralType(prim_type)) {
+ bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
+ return EmitIntegralMin(
+ max_value, EmitIntegralMax(min_value, arg_value, is_signed),
+ is_signed);
+ } else {
+ return Unimplemented("Clamp unimplemented for %s",
+ PrimitiveType_Name(prim_type).c_str());
+ }
};
case HloOpcode::kReducePrecision:
return [this, hlo, &operand_to_generator](
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 1a48eb5fcb..c516a826d9 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -86,6 +86,12 @@ class ElementalIrEmitter {
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) const;
+ llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
+ bool is_signed) const;
+
+ llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
+ bool is_signed) const;
+
virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
llvm::Value* value) const;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 23b72c3f71..affd2ffa8e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -615,8 +615,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
// TODO(b/33011107): Support cross replica sum on GPU.
- return Unimplemented(
- "Cross replica sum not implemented on GPU. See b/33011107.");
+ return Unimplemented("CrossReplicaSum is not implemented on GPU.");
}
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
@@ -710,11 +709,13 @@ Status IrEmitter::HandleCustomCall(HloInstruction*) {
}
Status IrEmitter::HandleInfeed(HloInstruction*) {
- return Unimplemented("Infeed is not supported on GPU (b/30467474).");
+ // TODO(b/30467474): Implement infeed on GPU.
+ return Unimplemented("Infeed is not supported on GPU.");
}
Status IrEmitter::HandleOutfeed(HloInstruction*) {
- return Unimplemented("Outfeed is not supported on GPU (b/34359662).");
+ // TODO(b/34359662): Implement outfeed on GPU.
+ return Unimplemented("Outfeed is not supported on GPU.");
}
Status IrEmitter::HandleRng(HloInstruction* random) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index fc8783e753..bd428f8028 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1658,8 +1658,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// TODO(b/31410564): Implement dilation rate for select-and-scatter.
if (window_util::HasDilation(window)) {
return Unimplemented(
- "Dilation for select-and-scatter not implemented on GPU. "
- "See b/31410564.");
+ "Dilation for SelectAndScatter not implemented on GPU.");
}
// kSelectAndScatter is implemented as two kernel launches: the first launch
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 56fc21d019..52e14a1f7b 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -1893,6 +1893,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
error_spec_);
}
+XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
+ ComputationBuilder builder(client_, TestName());
+ auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0, -5});
+ auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4, 10});
+ auto max_vector = builder.ConstantR1<int32>({3, 0, 25, 5, 123, -1});
+ auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);
+
+ ComputeAndCompareR1<int32>(&builder, {2, 0, 1, 2, 4, -1}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
+ ComputationBuilder builder(client_, TestName());
+ auto min_vector = builder.ConstantR1<uint32>({1, 2, 1, 2, 0, ~0u - 4});
+ auto arg_vector = builder.ConstantR1<uint32>({2, 10, 5, 1, 4, 10});
+ auto max_vector = builder.ConstantR1<uint32>({3, 5, 25, 5, 123, ~0u});
+ auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);
+
+ ComputeAndCompareR1<uint32>(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {});
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 39c07297d6..474d2547ae 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -376,6 +376,10 @@ class NearComparator {
abs_expected_miscompare_sum_ = 0.0;
max_rel_err_ = 0.0;
max_abs_err_ = 0.0;
+ first_linear_index_ = -1;
+ last_linear_index_ = -1;
+ max_rel_linear_index_ = -1;
+ max_abs_linear_index_ = -1;
miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED));
miscompares_.PopulateWithValue(false);
multi_index_.resize(expected.shape().dimensions_size(), 0);
@@ -482,11 +486,11 @@ class NearComparator {
const float rel_err = abs_diff / std::abs(expected);
abs_diff_sum_ += abs_diff;
abs_expected_sum_ += std::abs(expected);
- if (rel_err > max_rel_err_) {
+ if (rel_err > max_rel_err_ || std::isnan(rel_err)) {
max_rel_err_ = rel_err;
max_rel_linear_index_ = linear_index;
}
- if (abs_diff > max_abs_err_) {
+ if (abs_diff > max_abs_err_ || std::isnan(abs_diff)) {
max_abs_err_ = abs_diff;
max_abs_linear_index_ = linear_index;
}
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index e477784557..3a421f8458 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -97,5 +97,29 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
}
}
+TEST(LiteralTestUtilTest, NearComparatorR1) {
+ auto a =
+ Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ auto b =
+ Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+}
+
+TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
+ auto a =
+ Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
+ auto b =
+ Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
+ EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+}
+
+TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
+ auto a =
+ Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ auto b = Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
+ EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index a766fa2db0..50d7b5074d 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -494,6 +494,26 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
ErrorSpec(0.01, 1e-4));
}
+// Test that algebraic simplifier does not incorrectly fold a transpose into a
+// reduction operation.
+XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50});
+ ComputationDataHandle input = builder.Parameter(0, input_shape, "input");
+ ComputationDataHandle zero = builder.ConstantR0<float>(0.0);
+ ComputationDataHandle transpose =
+ builder.Transpose(input, /*permutation=*/{1, 0, 2});
+ ComputationDataHandle reduce =
+ builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> input_data,
+ MakeFakeLiteral(input_shape));
+
+ ComputeAndCompare(&builder, reduce, {std::move(*input_data)},
+ ErrorSpec(0.01, 1e-4));
+}
+
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
const int64 rows = 111, cols = 50;
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index debf2d2d31..43e4d891a1 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -852,5 +852,12 @@ XLA_TEST_F(ScalarComputationsTest, SqrtF320) {
ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
}
+XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Round(builder.ConstantR0<float>(1.4f));
+
+ ComputeAndCompareR0<float>(&builder, 1.0f, {}, error_spec_);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 8ec4be6dad..1c497c666b 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -24,6 +24,7 @@ py_library(
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/boosted_trees:init_py",
"//tensorflow/contrib/cloud:cloud_py",
+ "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/coder:coder_ops_py",
"//tensorflow/contrib/compiler:compiler_py",
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD
index 15abd2be03..80e18a43a7 100644
--- a/tensorflow/contrib/cluster_resolver/BUILD
+++ b/tensorflow/contrib/cluster_resolver/BUILD
@@ -34,6 +34,7 @@ py_library(
":cluster_resolver_py",
":gce_cluster_resolver_py",
":tpu_cluster_resolver_py",
+ "//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py
index d17501e87e..b4d8cd4a7c 100644
--- a/tensorflow/contrib/cluster_resolver/__init__.py
+++ b/tensorflow/contrib/cluster_resolver/__init__.py
@@ -26,3 +26,15 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
# pylint: enable=wildcard-import,unused-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'ClusterResolver',
+ 'SimpleClusterResolver',
+ 'UnionClusterResolver',
+ 'GceClusterResolver',
+ 'TPUClusterResolver',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 015f69c567..0c2827b1e4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -744,6 +744,23 @@ class BatchDatasetSerializationTest(
lambda: self._build_dataset_dense_to_sparse(diff_comp),
num_outputs)
+ def _sparse(self, i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ def _build_dataset_sparse(self, batch_size=5):
+ return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)
+
+ def testSparseCore(self):
+ self.run_core_tests(self._build_dataset_sparse,
+ lambda: self._build_dataset_sparse(2), 2)
+
+ def _build_dataset_nested_sparse(self):
+ return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)
+
+ def testNestedSparseCore(self):
+ self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
+
class PaddedBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index 701fc8247e..3f64475e47 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -41,8 +41,9 @@ class DatasetSerializationTestBase(test.TestCase):
def tearDown(self):
self._delete_ckpt()
- # TODO(b/70988345): Support native `tf.SparseTensor` objects and get rid of
- # `sparse_tensors` argument.
+ # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
+ # (deprecated) saveable `SparseTensorSliceDataset`, once the API
+ # `from_sparse_tensor_slices()`and related tests are deleted.
def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
"""Runs the core tests.
@@ -559,13 +560,16 @@ class DatasetSerializationTestBase(test.TestCase):
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
else:
get_next = iterator.get_next()
- self._add_iterator_ops_to_collection(init_op, get_next, sparse_tensors)
+ self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
+ sparse_tensors)
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
def _build_empty_graph(self, ds_fn, sparse_tensors=False):
iterator = iterator_ops.Iterator.from_structure(
- self._get_output_types(ds_fn), self._get_output_shapes(ds_fn))
+ self._get_output_types(ds_fn),
+ output_shapes=self._get_output_shapes(ds_fn),
+ output_classes=self._get_output_classes(ds_fn))
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
if sparse_tensors:
@@ -578,12 +582,19 @@ class DatasetSerializationTestBase(test.TestCase):
def _add_iterator_ops_to_collection(self,
init_op,
get_next,
+ ds_fn,
sparse_tensors=False):
ops.add_to_collection("iterator_ops", init_op)
# `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
# do not support tuples we flatten the tensors and restore the shape in
# `_get_iterator_ops_from_collection`.
- if sparse_tensors:
+
+ # TODO(shivaniagrwal): `output_classes` is a nested structure of classes,
+ # this base class is specific to current test cases. Update when tests are
+ # added with `output_classes` as a nested structure with at least one of the
+ # component being `tf.SparseTensor`.
+ if (sparse_tensors or
+ self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor):
ops.add_to_collection("iterator_ops", get_next.indices)
ops.add_to_collection("iterator_ops", get_next.values)
ops.add_to_collection("iterator_ops", get_next.dense_shape)
@@ -593,7 +604,8 @@ class DatasetSerializationTestBase(test.TestCase):
def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
all_ops = ops.get_collection("iterator_ops")
- if sparse_tensors:
+ if (sparse_tensors or
+ self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor):
init_op, indices, values, dense_shape = all_ops
return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
else:
@@ -608,6 +620,10 @@ class DatasetSerializationTestBase(test.TestCase):
with ops.Graph().as_default():
return ds_fn().output_shapes
+ def _get_output_classes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_classes
+
def _ckpt_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
index 5921be2ae8..06883934d0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
@@ -194,6 +194,10 @@ class FilterDatasetSerializationTest(
return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
lambda x, i: x)
+ def testSparseCore(self):
+ num_outputs = 5
+ self.run_core_tests(self._build_sparse_filter, None, num_outputs)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py
index d4fbaa5cdc..86d69495ef 100644
--- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py
@@ -225,6 +225,21 @@ class FlatMapDatasetSerializationTest(
self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError)
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _flat_map_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_ds():
+ return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
+
+ self.run_core_tests(_build_ds, None, 20)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index b1937c08f3..db8429512b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -252,6 +252,22 @@ class InterleaveDatasetSeriazationTest(
None, num_outputs)
# pylint: enable=g-long-lambda
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_dataset():
+ return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
+ _interleave_fn, cycle_length=1)
+
+ self.run_core_tests(_build_dataset, None, 20)
+
class ParallelInterleaveDatasetTest(test.TestCase):
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index dd8247bfd4..d3ce89298b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -805,6 +805,21 @@ class MapDatasetSerializationTest(
self.run_core_tests(_build_ds, None, num_outputs)
+ def testSparseCore(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def _build_ds(num_outputs):
+ return contrib_dataset_ops.Dataset.range(num_outputs).map(_sparse)
+
+ num_outputs = 10
+ self.run_core_tests(lambda: _build_ds(num_outputs),
+ lambda: _build_ds(int(num_outputs / 2)), num_outputs)
+
class ParallelMapDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
@@ -851,7 +866,8 @@ class ParallelMapDatasetSerializationTest(
return random_ops.random_uniform(
(), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
- return contrib_dataset_ops.Dataset.range(100).map(_map_fn)
+ return contrib_dataset_ops.Dataset.range(100).map(
+ _map_fn, num_parallel_calls=2).prefetch(2)
self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
@@ -861,7 +877,8 @@ class ParallelMapDatasetSerializationTest(
counter_var = variable_scope.get_variable(
"counter", (), dtypes.int32, use_resource=True)
return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda _: counter_var.assign_add(1)))
+ lambda _: counter_var.assign_add(1),
+ num_parallel_calls=2).prefetch(2))
self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
@@ -870,7 +887,7 @@ class ParallelMapDatasetSerializationTest(
def _build_ds():
constant_var = constant_op.constant(5)
return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda x: x + constant_var))
+ lambda x: x + constant_var, num_parallel_calls=2).prefetch(2))
self.run_core_tests(_build_ds, None, 10)
@@ -883,7 +900,8 @@ class ParallelMapDatasetSerializationTest(
def defun_fn(x):
return constant_op.constant(1000) + math_ops.to_int32(x)
- return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn)
+ return contrib_dataset_ops.Dataset.range(num_outputs).map(
+ defun_fn, num_parallel_calls=2).prefetch(2)
self.run_core_tests(_build_ds, None, num_outputs)
@@ -901,7 +919,8 @@ class ParallelMapDatasetSerializationTest(
return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
- return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn)
+ return contrib_dataset_ops.Dataset.range(num_outputs).map(
+ defun_fn, num_parallel_calls=2).prefetch(2)
self.run_core_tests(_build_ds, None, num_outputs)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py
new file mode 100644
index 0000000000..ea3c86b5c0
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py
@@ -0,0 +1,388 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import kumaraswamy as kumaraswamy_lib
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def try_import(name): # pylint: disable=invalid-name
+ module = None
+ try:
+ module = importlib.import_module(name)
+ except ImportError as e:
+ tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+ return module
+
+
+special = try_import("scipy.special")
+stats = try_import("scipy.stats")
+
+
+def _kumaraswamy_mode(a, b):
+ a = np.asarray(a)
+ b = np.asarray(b)
+ return ((a - 1) / (a * b - 1))**(1 / a)
+
+
+def _kumaraswamy_moment(a, b, n):
+ a = np.asarray(a)
+ b = np.asarray(b)
+ return b * special.beta(1.0 + n / a, b)
+
+
+def _harmonic_number(b):
+ b = np.asarray(b)
+ return special.psi(b + 1) - special.psi(1)
+
+
+def _kumaraswamy_cdf(a, b, x):
+ a = np.asarray(a)
+ b = np.asarray(b)
+ x = np.asarray(x)
+ return 1 - (1 - x**a)**b
+
+
+def _kumaraswamy_pdf(a, b, x):
+ a = np.asarray(a)
+ b = np.asarray(b)
+ x = np.asarray(x)
+ return a * b * x ** (a - 1) * (1 - x ** a) ** (b - 1)
+
+
+class KumaraswamyTest(test.TestCase):
+
+ def testSimpleShapes(self):
+ with self.test_session():
+ a = np.random.rand(3)
+ b = np.random.rand(3)
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertAllEqual([], dist.event_shape_tensor().eval())
+ self.assertAllEqual([3], dist.batch_shape_tensor().eval())
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
+
+ def testComplexShapes(self):
+ with self.test_session():
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(3, 2, 2)
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertAllEqual([], dist.event_shape_tensor().eval())
+ self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval())
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+
+ def testComplexShapesBroadcast(self):
+ with self.test_session():
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(2, 2)
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertAllEqual([], dist.event_shape_tensor().eval())
+ self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval())
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+
+ def testAProperty(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertEqual([1, 3], dist.concentration1.get_shape())
+ self.assertAllClose(a, dist.concentration1.eval())
+
+ def testBProperty(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertEqual([1, 3], dist.concentration0.get_shape())
+ self.assertAllClose(b, dist.concentration0.eval())
+
+ def testPdfXProper(self):
+ a = [[1., 2, 3]]
+ b = [[2., 4, 3]]
+ with self.test_session():
+ dist = kumaraswamy_lib.Kumaraswamy(a, b, validate_args=True)
+ dist.prob([.1, .3, .6]).eval()
+ dist.prob([.2, .3, .5]).eval()
+ # Either condition can trigger.
+ with self.assertRaisesOpError("sample must be positive"):
+ dist.prob([-1., 0.1, 0.5]).eval()
+ with self.assertRaisesOpError("sample must be positive"):
+ dist.prob([0., 0.1, 0.5]).eval()
+ with self.assertRaisesOpError("sample must be no larger than `1`"):
+ dist.prob([.1, .2, 1.2]).eval()
+
+ def testPdfTwoBatches(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [.5, .5]
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ pdf = dist.prob(x)
+ expected_pdf = _kumaraswamy_pdf(a, b, x)
+ self.assertAllClose(expected_pdf, pdf.eval())
+ self.assertEqual((2,), pdf.get_shape())
+
+ def testPdfTwoBatchesNontrivialX(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [.3, .7]
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ pdf = dist.prob(x)
+ expected_pdf = _kumaraswamy_pdf(a, b, x)
+ self.assertAllClose(expected_pdf, pdf.eval())
+ self.assertEqual((2,), pdf.get_shape())
+
+ def testPdfUniformZeroBatch(self):
+ with self.test_session():
+ # This is equivalent to a uniform distribution
+ a = 1.
+ b = 1.
+ x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ pdf = dist.prob(x)
+ expected_pdf = _kumaraswamy_pdf(a, b, x)
+ self.assertAllClose(expected_pdf, pdf.eval())
+ self.assertEqual((5,), pdf.get_shape())
+
+ def testPdfAStretchedInBroadcastWhenSameRank(self):
+ with self.test_session():
+ a = [[1., 2]]
+ b = [[1., 2]]
+ x = [[.5, .5], [.3, .7]]
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ pdf = dist.prob(x)
+ expected_pdf = _kumaraswamy_pdf(a, b, x)
+ self.assertAllClose(expected_pdf, pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfAStretchedInBroadcastWhenLowerRank(self):
+ with self.test_session():
+ a = [1., 2]
+ b = [1., 2]
+ x = [[.5, .5], [.2, .8]]
+ pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x)
+ expected_pdf = _kumaraswamy_pdf(a, b, x)
+ self.assertAllClose(expected_pdf, pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfXStretchedInBroadcastWhenSameRank(self):
+ with self.test_session():
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [[.5, .5]]
+ pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x)
+ expected_pdf = _kumaraswamy_pdf(a, b, x)
+ self.assertAllClose(expected_pdf, pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testPdfXStretchedInBroadcastWhenLowerRank(self):
+ with self.test_session():
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [.5, .5]
+ pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x)
+ expected_pdf = _kumaraswamy_pdf(a, b, x)
+ self.assertAllClose(expected_pdf, pdf.eval())
+ self.assertEqual((2, 2), pdf.get_shape())
+
+ def testKumaraswamyMean(self):
+ with session.Session():
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertEqual(dist.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_mean = _kumaraswamy_moment(a, b, 1)
+ self.assertAllClose(expected_mean, dist.mean().eval())
+
+ def testKumaraswamyVariance(self):
+ with session.Session():
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertEqual(dist.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variance = _kumaraswamy_moment(a, b, 2) - _kumaraswamy_moment(
+ a, b, 1)**2
+ self.assertAllClose(expected_variance, dist.variance().eval())
+
+ def testKumaraswamyMode(self):
+ with session.Session():
+ a = np.array([1.1, 2, 3])
+ b = np.array([2., 4, 1.2])
+ expected_mode = _kumaraswamy_mode(a, b)
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertEqual(dist.mode().get_shape(), (3,))
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ def testKumaraswamyModeInvalid(self):
+ with session.Session():
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ dist.mode().eval()
+
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ dist.mode().eval()
+
+ def testKumaraswamyModeEnableAllowNanStats(self):
+ with session.Session():
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=True)
+
+ expected_mode = _kumaraswamy_mode(a, b)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=True)
+
+ expected_mode = _kumaraswamy_mode(a, b)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, dist.mode().eval())
+
+ def testKumaraswamyEntropy(self):
+ with session.Session():
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = kumaraswamy_lib.Kumaraswamy(a, b)
+ self.assertEqual(dist.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = (1 - 1. / a) + (
+ 1 - 1. / b) * _harmonic_number(b) + np.log(a * b)
+ self.assertAllClose(expected_entropy, dist.entropy().eval())
+
+ def testKumaraswamySample(self):
+ with self.test_session():
+ a = 1.
+ b = 2.
+ kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b)
+ n = constant_op.constant(100000)
+ samples = kumaraswamy.sample(n)
+ sample_values = samples.eval()
+ self.assertEqual(sample_values.shape, (100000,))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ self.assertLess(
+ stats.kstest(
+ # Kumaraswamy is a univariate distribution.
+ sample_values,
+ lambda x: _kumaraswamy_cdf(1., 2., x))[0],
+ 0.01)
+ # The standard error of the sample mean is 1 / (sqrt(18 * n))
+ expected_mean = _kumaraswamy_moment(a, b, 1)
+ self.assertAllClose(sample_values.mean(axis=0), expected_mean, atol=1e-2)
+ expected_variance = _kumaraswamy_moment(a, b, 2) - _kumaraswamy_moment(
+ a, b, 1)**2
+ self.assertAllClose(
+ np.cov(sample_values, rowvar=0), expected_variance, atol=1e-1)
+
+ # Test that sampling with the same seed twice gives the same results.
+ def testKumaraswamySampleMultipleTimes(self):
+ with self.test_session():
+ a_val = 1.
+ b_val = 2.
+ n_val = 100
+
+ random_seed.set_random_seed(654321)
+ kumaraswamy1 = kumaraswamy_lib.Kumaraswamy(
+ concentration1=a_val, concentration0=b_val, name="kumaraswamy1")
+ samples1 = kumaraswamy1.sample(n_val, seed=123456).eval()
+
+ random_seed.set_random_seed(654321)
+ kumaraswamy2 = kumaraswamy_lib.Kumaraswamy(
+ concentration1=a_val, concentration0=b_val, name="kumaraswamy2")
+ samples2 = kumaraswamy2.sample(n_val, seed=123456).eval()
+
+ self.assertAllClose(samples1, samples2)
+
+ def testKumaraswamySampleMultidimensional(self):
+ with self.test_session():
+ a = np.random.rand(3, 2, 2).astype(np.float32)
+ b = np.random.rand(3, 2, 2).astype(np.float32)
+ kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b)
+ n = constant_op.constant(100000)
+ samples = kumaraswamy.sample(n)
+ sample_values = samples.eval()
+ self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values[:, 1, :].mean(axis=0),
+ _kumaraswamy_moment(a, b, 1)[1, :],
+ atol=1e-1)
+
+ def testKumaraswamyCdf(self):
+ with self.test_session():
+ shape = (30, 40, 50)
+ for dt in (np.float32, np.float64):
+ a = 10. * np.random.random(shape).astype(dt)
+ b = 10. * np.random.random(shape).astype(dt)
+ x = np.random.random(shape).astype(dt)
+ actual = kumaraswamy_lib.Kumaraswamy(a, b).cdf(x).eval()
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+ if not stats:
+ return
+ self.assertAllClose(
+ _kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0)
+
+ def testKumaraswamyLogCdf(self):
+ with self.test_session():
+ shape = (30, 40, 50)
+ for dt in (np.float32, np.float64):
+ a = 10. * np.random.random(shape).astype(dt)
+ b = 10. * np.random.random(shape).astype(dt)
+ x = np.random.random(shape).astype(dt)
+ actual = math_ops.exp(kumaraswamy_lib.Kumaraswamy(a,
+ b).log_cdf(x)).eval()
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+ if not stats:
+ return
+ self.assertAllClose(
+ _kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
new file mode 100644
index 0000000000..74d5d8773c
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
@@ -0,0 +1,258 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Kumaraswamy distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import special_math_ops
+from tensorflow.python.ops.distributions import beta
+from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
+
+__all__ = [
+ "Kumaraswamy",
+]
+
+_kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in
+`[0, 1].` It must have a shape compatible with `self.batch_shape()`."""
+
+
+def _harmonic_number(x):
+ """Compute the harmonic number from its analytic continuation.
+
+ Derivation from [1] and Euler's constant [2].
+ [1] -
+ https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers
+ [2] - https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant
+
+
+ Args:
+ x: input float.
+
+ Returns:
+ z: The analytic continuation of the harmonic number for the input.
+
+ """
+ one = array_ops.ones([], dtype=x.dtype)
+ return math_ops.digamma(x + one) - math_ops.digamma(one)
+
+
+@tf_export("distributions.Kumaraswamy")
+class Kumaraswamy(beta.Beta):
+ """Kumaraswamy distribution.
+
+ The Kumaraswamy distribution is defined over the `(0, 1)` interval using
+ parameters
+ `concentration1` (aka "alpha") and `concentration0` (aka "beta"). It has a
+ shape similar to the Beta distribution, but is reparameterizeable.
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; alpha, beta) = alpha * beta * x**(alpha - 1) * (1 - x**alpha)**(beta -
+ 1)
+ ```
+
+ where:
+
+ * `concentration1 = alpha`,
+ * `concentration0 = beta`,
+
+ Distribution parameters are automatically broadcast in all functions; see
+ examples for details.
+
+ #### Examples
+
+ ```python
+ # Create a batch of three Kumaraswamy distributions.
+ alpha = [1, 2, 3]
+ beta = [1, 2, 3]
+ dist = Kumaraswamy(alpha, beta)
+
+ dist.sample([4, 5]) # Shape [4, 5, 3]
+
+ # `x` has three batch entries, each with two samples.
+ x = [[.1, .4, .5],
+ [.2, .3, .5]]
+ # Calculate the probability of each pair of samples under the corresponding
+ # distribution in `dist`.
+ dist.prob(x) # Shape [2, 3]
+ ```
+
+ ```python
+ # Create batch_shape=[2, 3] via parameter broadcast:
+ alpha = [[1.], [2]] # Shape [2, 1]
+ beta = [3., 4, 5] # Shape [3]
+ dist = Kumaraswamy(alpha, beta)
+
+ # alpha broadcast as: [[1., 1, 1,],
+ # [2, 2, 2]]
+ # beta broadcast as: [[3., 4, 5],
+ # [3, 4, 5]]
+ # batch_Shape [2, 3]
+ dist.sample([4, 5]) # Shape [4, 5, 2, 3]
+
+ x = [.2, .3, .5]
+ # x will be broadcast as [[.2, .3, .5],
+ # [.2, .3, .5]],
+ # thus matching batch_shape [2, 3].
+ dist.prob(x) # Shape [2, 3]
+ ```
+
+ """
+
+ def __init__(self,
+ concentration1=None,
+ concentration0=None,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="Kumaraswamy"):
+ """Initialize a batch of Kumaraswamy distributions.
+
+ Args:
+ concentration1: Positive floating-point `Tensor` indicating mean
+ number of successes; aka "alpha". Implies `self.dtype` and
+ `self.batch_shape`, i.e.,
+ `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
+ concentration0: Positive floating-point `Tensor` indicating mean
+ number of failures; aka "beta". Otherwise has same semantics as
+ `concentration1`.
+ validate_args: Python `bool`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: Python `str` name prefixed to Ops created by this class.
+ """
+ super(Kumaraswamy, self).__init__(
+ concentration1=concentration1,
+ concentration0=concentration0,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ name=name)
+ self._reparameterization_type = distribution.FULLY_REPARAMETERIZED
+
+ def _sample_n(self, n, seed=None):
+ expanded_concentration1 = array_ops.ones_like(
+ self.total_concentration, dtype=self.dtype) * self.concentration1
+ expanded_concentration0 = array_ops.ones_like(
+ self.total_concentration, dtype=self.dtype) * self.concentration0
+ shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
+ uniform_sample = random_ops.random_uniform(
+ shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed)
+
+ kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**(
+ 1. / expanded_concentration1)
+ return kumaraswamy_sample
+
+ @distribution_util.AppendDocstring(_kumaraswamy_sample_note)
+ def _log_cdf(self, x):
+ a = self.concentration1
+ b = self.concentration0
+ return math_ops.log1p(-(1 - x**a)**b)
+
+ @distribution_util.AppendDocstring(_kumaraswamy_sample_note)
+ def _cdf(self, x):
+ a = self.concentration1
+ b = self.concentration0
+ return 1 - (1 - x**a)**b
+
+ def _survival_function(self, x):
+ a = self.concentration1
+ b = self.concentration0
+ return (1 - x**a)**b
+
+ def _log_survival_function(self, x):
+ a = self.concentration1
+ b = self.concentration0
+ return b * math_ops.log1p(-x**a)
+
+ def _log_unnormalized_prob(self, x):
+ x = self._maybe_assert_valid_sample(x)
+ a = self.concentration1
+ b = self.concentration0
+ return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a)
+
+ def _log_normalization(self):
+ a = self.concentration1
+ b = self.concentration0
+ return -(math_ops.log(a) + math_ops.log(b))
+
+ def _entropy(self):
+ a = self.concentration1
+ b = self.concentration0
+ return (1 - 1. / a) + (
+ 1 - 1. / b) * _harmonic_number(b) + math_ops.log(a) + math_ops.log(b)
+
+ def _moment(self, n):
+ """Compute the n'th (uncentered) moment."""
+ expanded_concentration1 = array_ops.ones_like(
+ self.total_concentration, dtype=self.dtype) * self.concentration1
+ expanded_concentration0 = array_ops.ones_like(
+ self.total_concentration, dtype=self.dtype) * self.concentration0
+ beta_arg0 = 1 + n / expanded_concentration1
+ beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1)
+ log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta(
+ beta_arg)
+ return math_ops.exp(log_moment)
+
+ def _mean(self):
+ return self._moment(1)
+
+ def _variance(self):
+ # TODO(b/72696533): Investigate a more numerically stable version.
+ return self._moment(2) - math_ops.square(self._moment(1))
+
+ @distribution_util.AppendDocstring(
+ """Note: The mode is undefined when `concentration1 <= 1` or
+ `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
+ is used for undefined modes. If `self.allow_nan_stats` is `False` an
+ exception is raised when one or more modes are undefined.""")
+ def _mode(self):
+ a = self.concentration1
+ b = self.concentration0
+ mode = ((a - 1) / (a * b - 1))**(1. / a)
+ if self.allow_nan_stats:
+ nan = array_ops.fill(
+ self.batch_shape_tensor(),
+ np.array(np.nan, dtype=self.dtype.as_numpy_dtype),
+ name="nan")
+ is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.)
+ return array_ops.where(is_defined, mode, nan)
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_less(
+ array_ops.ones([], dtype=self.dtype),
+ self.concentration1,
+ message="Mode undefined for concentration1 <= 1."),
+ check_ops.assert_less(
+ array_ops.ones([], dtype=self.dtype),
+ self.concentration0,
+ message="Mode undefined for concentration0 <= 1.")
+ ], mode)
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
index 0dedb2fd7c..b6659c2a17 100644
--- a/tensorflow/contrib/eager/python/tfe_test.py
+++ b/tensorflow/contrib/eager/python/tfe_test.py
@@ -102,10 +102,6 @@ class TFETest(test_util.TensorFlowTestCase):
# Expect at least one device.
self.assertTrue(tfe.list_devices())
- def testNumGPUs(self):
- devices = tfe.list_devices()
- self.assertEqual(len(devices) - 1, tfe.num_gpus())
-
def testAddCheckNumericsOpsRaisesError(self):
with self.assertRaisesRegexp(
RuntimeError,
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index 4d0f9b2424..c861cfff54 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -143,7 +143,7 @@ class _ModelFn(object):
def model_fn(self, features, mode, config):
"""Model function for the estimator.
- Note that this does not take a `1abels` arg. This works, but `input_fn` must
+ Note that this does not take a `labels` arg. This works, but `input_fn` must
return either `features` or, equivalently, `(features, None)`.
Args:
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index f59168cbc0..bcba18ae14 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -111,6 +112,54 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di
return array_ops.ones(shape, dtype)
+def extract_image_patches(image, ksizes, strides, padding, name=None):
+ """Extracts image patches for an N-dimensional convolution.
+
+ This function is a compatibility wrapper over tf.extract_image_patches(), as
+ ExtractImagePatches isn't yet implemented in XLA.
+
+ Args:
+ image: Tensor of shape [batch, in_x, in_y, ..., in_channels]. Input images.
+ All dimensions except 'batch' must be defined.
+ ksizes: [filter_x, filter_y, ...]. Spatial shape of filter in each
+ dimension.
+ strides: [stride_x, stride_y, ...]. Spatial stride for filter in each
+ dimension.
+ padding: str. "VALID" or "SAME".
+ name: str or None. name of Op.
+
+ Returns:
+ result: [batch, out_x, out_y, ..., filter_x, filter_y, ..., in_channels].
+ Contains image patches to which conv kernel would be applied for each
+ output location. [out_x, out_y, ...] depends on padding.
+ """
+ if not utils.on_tpu():
+ return array_ops.extract_image_patches(
+ image,
+ ksizes=([1] + list(ksizes) + [1]),
+ strides=([1] + list(strides) + [1]),
+ rates=[1, 1, 1, 1],
+ padding=padding,
+ name=name)
+
+ with tf_ops.name_scope(name, "extract_image_patches",
+ [image, ksizes, strides, padding]):
+ batch = image.shape.as_list()[0]
+ in_channels = image.shape.as_list()[-1]
+
+ # Map each input feature to a location in the output.
+ out_channels = np.prod(ksizes) * in_channels
+ filters = linalg_ops.eye(out_channels),
+ filters = array_ops.reshape(filters, ksizes + [in_channels, out_channels])
+
+ result = nn.convolution(image, filters, padding, strides=strides)
+ out_spatial = result.shape.as_list()[1:-1]
+ result = array_ops.reshape(
+ result, [batch or -1] + out_spatial + ksizes + [in_channels])
+
+ return result
+
+
def compute_cov(tensor, tensor_right=None, normalizer=None):
"""Compute the empirical second moment of the rows of a 2D Tensor.
@@ -668,11 +717,10 @@ class ConvDiagonalFactor(DiagonalFactor):
# TODO(b/64144716): there is potential here for a big savings in terms
# of memory use.
- patches = array_ops.extract_image_patches(
+ patches = extract_image_patches(
self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
+ ksizes=[filter_height, filter_width],
+ strides=self._strides[1:-1],
padding=self._padding)
if self._has_bias:
@@ -816,11 +864,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
# TODO(b/64144716): there is potential here for a big savings in terms of
# memory use.
- patches = array_ops.extract_image_patches(
+ patches = extract_image_patches(
self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
+ ksizes=[filter_height, filter_width],
+ strides=self._strides[1:-1],
padding=self._padding)
flatten_size = (filter_height * filter_width * in_channels)
diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
index 10f31bb6f1..d74e275f04 100644
--- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
@@ -225,14 +225,8 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const
assert(pixelBuffer != NULL);
OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer);
- int doReverseChannels;
- if (kCVPixelFormatType_32ARGB == sourcePixelFormat) {
- doReverseChannels = 1;
- } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) {
- doReverseChannels = 0;
- } else {
- assert(false); // Unknown source format
- }
+ assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
+ sourcePixelFormat == kCVPixelFormatType_32BGRA);
const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer);
const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer);
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index fb5764f280..63ea89df56 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
- for (int i = 0; i < NumDimensions(input1); ++i) {
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
- SizeOfDimension(input2, i));
- }
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+ output->type = input2->type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
- TF_LITE_ENSURE_EQ(context, input1->type, output->type);
- TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
- TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, TfLiteTensor* input1,
- TfLiteTensor* input2, TfLiteTensor* output) {
+ TfLiteAddParams* params, const OpData* data,
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
-#define TF_LITE_ADD(type) \
- type::Add(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_ADD(type, opname) \
+ type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
+ GetTensorData<float>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
if (kernel_type == kReference) {
- TF_LITE_ADD(reference_ops);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd);
+ } else {
+ TF_LITE_ADD(reference_ops, Add);
+ }
} else {
- TF_LITE_ADD(optimized_ops);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add);
+ }
}
#undef TF_LITE_ADD
}
template <KernelType kernel_type>
void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, TfLiteTensor* input1,
- TfLiteTensor* input2, TfLiteTensor* output) {
+ TfLiteAddParams* params, const OpData* data,
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point;
auto output_offset = output->params.zero_point;
@@ -112,19 +141,20 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRangeUint8(params->activation, output,
&output_activation_min, &output_activation_max);
-#define TF_LITE_ADD(type) \
- type::BroadcastAdd( \
- left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, input1_multiplier, input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), input2_offset, \
- input2_multiplier, input2_shift, output_offset, output_multiplier, \
- output_shift, output_activation_min, output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
-
+#define TF_LITE_ADD(type, opname) \
+ type::opname(left_shift, GetTensorData<uint8_t>(input1), \
+ GetTensorDims(input1), input1_offset, input1_multiplier, \
+ input1_shift, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), input2_offset, input2_multiplier, \
+ input2_shift, output_offset, output_multiplier, output_shift, \
+ output_activation_min, output_activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output));
+ // The quantized version of Add doesn't support activations, so we
+ // always use BroadcastAdd.
if (kernel_type == kReference) {
- TF_LITE_ADD(reference_ops);
+ TF_LITE_ADD(reference_ops, BroadcastAdd);
} else {
- TF_LITE_ADD(optimized_ops);
+ TF_LITE_ADD(optimized_ops, BroadcastAdd);
}
#undef TF_LITE_ADD
}
@@ -132,15 +162,17 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
- EvalAddFloat<kernel_type>(context, node, params, input1, input2, output);
+ EvalAddFloat<kernel_type>(context, node, params, data, input1, input2,
+ output);
} else if (output->type == kTfLiteUInt8) {
- EvalAddQuantized<kernel_type>(context, node, params, input1, input2,
+ EvalAddQuantized<kernel_type>(context, node, params, data, input1, input2,
output);
} else {
context->ReportError(context,
@@ -154,19 +186,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace add
TfLiteRegistration* Register_ADD_REF() {
- static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ static TfLiteRegistration r = {add::Init, add::Free, add::Prepare,
add::Eval<add::kReference>};
return &r;
}
TfLiteRegistration* Register_ADD_GENERIC_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ static TfLiteRegistration r = {add::Init, add::Free, add::Prepare,
add::Eval<add::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_ADD_NEON_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ static TfLiteRegistration r = {add::Init, add::Free, add::Prepare,
add::Eval<add::kNeonOptimized>};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc
index 306dfc3e80..956d05bed5 100644
--- a/tensorflow/contrib/lite/kernels/add_test.cc
+++ b/tensorflow/contrib/lite/kernels/add_test.cc
@@ -25,10 +25,11 @@ using ::testing::ElementsAreArray;
class BaseAddOpModel : public SingleOpModel {
public:
- BaseAddOpModel(const TensorData& input, const TensorData& output,
+ BaseAddOpModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output,
ActivationFunctionType activation_type) {
- input1_ = AddInput(input);
- input2_ = AddInput(input);
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
CreateAddOptions(builder_, activation_type).Union());
@@ -70,6 +71,7 @@ float GetTolerance(int min, int max) {
TEST(FloatAddOpModel, NoActivation) {
FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
@@ -78,9 +80,9 @@ TEST(FloatAddOpModel, NoActivation) {
}
TEST(FloatAddOpModel, ActivationRELU_N1_TO_1) {
- FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
- {TensorType_FLOAT32, {}},
- ActivationFunctionType_RELU_N1_TO_1);
+ FloatAddOpModel m(
+ {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1);
m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
m.Invoke();
@@ -92,6 +94,7 @@ TEST(FloatAddOpModel, VariousInputShapes) {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]},
+ {TensorType_FLOAT32, test_shapes[i]},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1});
@@ -102,6 +105,23 @@ TEST(FloatAddOpModel, VariousInputShapes) {
}
}
+TEST(FloatAddOpModel, WithBroadcast) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]},
+ {TensorType_FLOAT32, {}}, // always a scalar
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.PopulateTensor<float>(m.input2(), {0.1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1})))
+ << "With shape number " << i;
+ }
+}
+
TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<std::initializer_list<float>> inputs1 = {
@@ -112,6 +132,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) {
{0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}};
for (int i = 0; i < inputs1.size(); ++i) {
QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
{TensorType_UINT8, {}, -1.0, 1.0},
ActivationFunctionType_NONE);
m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[i]);
@@ -133,6 +154,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) {
{-0.2, 0.6, -0.1, 0.8}};
for (int i = 0; i < inputs1.size(); ++i) {
QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
{TensorType_UINT8, {}, -1.0, 1.0},
ActivationFunctionType_RELU_N1_TO_1);
m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[i]);
@@ -150,6 +172,7 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0},
+ {TensorType_UINT8, test_shapes[i], -3.0, 3.0},
{TensorType_UINT8, {}, -3.0, 3.0},
ActivationFunctionType_NONE);
m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
@@ -162,6 +185,25 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) {
}
}
+TEST(QuantizedAddOpModel, QuantizedWithBroadcast) {
+ float kQuantizedTolerance = GetTolerance(-3.0, 3.0);
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0},
+ {TensorType_UINT8, {}, -3.0, 3.0},
+ {TensorType_UINT8, {}, -3.0, 3.0},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1},
+ kQuantizedTolerance)))
+ << "With shape number " << i;
+ }
+}
+
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index d84a77039b..889239f932 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -57,6 +57,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
BatchToSpaceNDContext* op_context) {
TfLiteIntArray* input_size = op_context->input->dims;
const int* block_shape = GetTensorData<int32>(op_context->block_shape);
+ const int* crops = GetTensorData<int32>(op_context->crops);
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
kBlockSizeDimensionNum);
@@ -65,7 +66,13 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
kSpatialDimensionNum);
- // TODO(ycling): Add crops as part of calculation.
+ // TODO(ycling): Add crops as part of calculation. Remove check for a crops
+ // containing all zeroes.
+ TF_LITE_ENSURE_EQ(context, crops[0], 0);
+ TF_LITE_ENSURE_EQ(context, crops[1], 0);
+ TF_LITE_ENSURE_EQ(context, crops[2], 0);
+ TF_LITE_ENSURE_EQ(context, crops[3], 0);
+
// Number of batch must be multiple of (block_shape[0] * block_shape[1]).
TF_LITE_ENSURE_EQ(context,
input_size->data[0] % (block_shape[0] * block_shape[1]), 0);
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
index c9152bf967..8485cde1b4 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
@@ -119,6 +119,19 @@ TEST(BatchToSpaceNDOpTest, InvalidShapeTest) {
"Cannot allocate tensors");
}
+TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) {
+ EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}),
+ "1 != 0");
+}
+
+TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) {
+ BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ m.SetBlockShape({2, 2});
+ m.SetCrops({0, 0, 1, 0});
+ EXPECT_DEATH(m.Invoke(), "1 != 0");
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 81c73f2523..54575019de 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
- for (int i = 0; i < NumDimensions(input1); ++i) {
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
- SizeOfDimension(input2, i));
- }
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+ output->type = input2->type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
- TF_LITE_ENSURE_EQ(context, input1->type, output->type);
- TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
- TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, TfLiteTensor* input1,
- TfLiteTensor* input2, TfLiteTensor* output) {
+ TfLiteMulParams* params, const OpData* data,
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
-#define TF_LITE_MUL(type) \
- type::Mul(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
+ GetTensorData<float>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops);
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(reference_ops, BroadcastMul);
+ } else {
+ TF_LITE_MUL(reference_ops, Mul);
+ }
} else {
- TF_LITE_MUL(optimized_ops);
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(optimized_ops, BroadcastMul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
}
#undef TF_LITE_MUL
}
template <KernelType kernel_type>
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, TfLiteTensor* input1,
- TfLiteTensor* input2, TfLiteTensor* output) {
+ TfLiteMulParams* params, const OpData* data,
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point;
auto output_offset = output->params.zero_point;
@@ -98,17 +127,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRangeUint8(params->activation, output,
&output_activation_min, &output_activation_max);
-#define TF_LITE_MUL(type) \
- type::BroadcastMul(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, output_offset, \
- output_multiplier, output_shift, output_activation_min, \
- output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ input1_offset, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), input2_offset, output_offset, \
+ output_multiplier, output_shift, output_activation_min, \
+ output_activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output));
+ // The quantized version of Mul doesn't support activations, so we
+ // always use BroadcastMul.
if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops);
+ TF_LITE_MUL(reference_ops, BroadcastMul);
} else {
- TF_LITE_MUL(optimized_ops);
+ TF_LITE_MUL(optimized_ops, BroadcastMul);
}
#undef TF_LITE_MUL
}
@@ -116,15 +147,17 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, input1, input2, output);
+ EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8) {
- EvalQuantized<kernel_type>(context, node, params, input1, input2, output);
+ EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
+ output);
} else {
context->ReportError(context,
"Mul only supports FLOAT32 and quantized UINT8 now.");
@@ -137,19 +170,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace mul
TfLiteRegistration* Register_MUL_REF() {
- static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
mul::Eval<mul::kReference>};
return &r;
}
TfLiteRegistration* Register_MUL_GENERIC_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
mul::Eval<mul::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_MUL_NEON_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
mul::Eval<mul::kNeonOptimized>};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc
index 8838b300c0..f1a30f8263 100644
--- a/tensorflow/contrib/lite/kernels/mul_test.cc
+++ b/tensorflow/contrib/lite/kernels/mul_test.cc
@@ -25,10 +25,11 @@ using ::testing::ElementsAreArray;
class BaseMulOpModel : public SingleOpModel {
public:
- BaseMulOpModel(TensorData input, TensorData output,
+ BaseMulOpModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output,
ActivationFunctionType activation_type) {
- input1_ = AddInput(input);
- input2_ = AddInput(input);
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
CreateMulOptions(builder_, activation_type).Union());
@@ -70,6 +71,7 @@ class QuantizedMulOpModel : public BaseMulOpModel {
TEST(FloatMulOpTest, NoActivation) {
FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
@@ -79,9 +81,9 @@ TEST(FloatMulOpTest, NoActivation) {
}
TEST(FloatMulOpTest, ActivationRELU_N1_TO_1) {
- FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
- {TensorType_FLOAT32, {}},
- ActivationFunctionType_RELU_N1_TO_1);
+ FloatMulOpModel m(
+ {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1);
m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 5});
m.Invoke();
@@ -94,6 +96,7 @@ TEST(FloatMulOpTest, VariousInputShapes) {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]},
+ {TensorType_FLOAT32, test_shapes[i]},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1});
@@ -105,8 +108,26 @@ TEST(FloatMulOpTest, VariousInputShapes) {
}
}
+TEST(FloatMulOpTest, WithBroadcast) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]},
+ {TensorType_FLOAT32, {}}, // always a scalar
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.PopulateTensor<float>(m.input2(), {0.1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.2, 0.02, 0.07, 0.08, 0.11, 0.2})))
+ << "With shape number " << i;
+ }
+}
+
TEST(QuantizedMulOpTest, NoActivation) {
QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
{TensorType_UINT8, {}, -1.0, 1.0},
ActivationFunctionType_NONE);
m.QuantizeAndPopulate<uint8_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
@@ -117,6 +138,32 @@ TEST(QuantizedMulOpTest, NoActivation) {
kQuantizedTolerance)));
}
+// for quantized Mul, the error shouldn't exceed 2*step
+float GetTolerance(int min, int max) {
+ float kQuantizedStep = (max - min) / 255.0;
+ float kQuantizedTolerance = 2.0 * kQuantizedStep;
+ return kQuantizedTolerance;
+}
+
+TEST(QuantizedMulOpTest, WithBroadcast) {
+ float kQuantizedTolerance = GetTolerance(-3.0, 3.0);
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ QuantizedMulOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0},
+ {TensorType_UINT8, {}, -3.0, 3.0}, // always a scalar
+ {TensorType_UINT8, {}, -3.0, 3.0},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}, kQuantizedTolerance)))
+ << "With shape number " << i;
+ }
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index c510ee3b9f..c4ffdf79d3 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -57,63 +57,6 @@ struct StridedSliceContext {
int dims;
};
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
- StridedSliceContext op_context(context, node);
-
- // Ensure validity of input tensor and its dimension
- TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
- TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
- TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
- TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
- // Only INT32 begin/end/strides are supported
- // TODO(soroosh) add support for INT64
- TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
- TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
- TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
- TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
- "StridedSlice op only supports 1D-4D input arrays.");
-
- // TODO(soroosh): add the following missing functionalities
- TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
- "ellipsis_mask is not implemented yet.");
- TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
- "new_axis_mask is not implemented yet.");
-
- // TODO(soroosh): optimize for constant tensors to do allocation in Prepare
- op_context.output->allocation_type = kTfLiteDynamic;
- return kTfLiteOk;
-} // namespace strided_slice
-
-// TODO(soroosh): consolidate with BytesRequired in interpreter.h
-TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type,
- const int* dims, int dims_size, size_t* bytes) {
- // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
- // MultiplyWithoutOverflow.
- TF_LITE_ENSURE(context, bytes != nullptr);
- size_t count = 1;
- for (int k = 0; k < dims_size; k++) count *= dims[k];
- switch (type) {
- case kTfLiteFloat32:
- *bytes = sizeof(float) * count;
- break;
- case kTfLiteInt32:
- *bytes = sizeof(int32_t) * count;
- break;
- case kTfLiteUInt8:
- *bytes = sizeof(uint8_t) * count;
- break;
- case kTfLiteInt64:
- *bytes = sizeof(int64_t) * count;
- break;
- default:
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
// Reverse order of bits in the mask to match the expected order in kernel
inline int ReverseMaskBits(int mask, int num_dimensions) {
int out = 0;
@@ -144,43 +87,44 @@ inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
std::min(std::max(index, -dim), dim - 1), dim));
}
-template <KernelType kernel_type>
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- StridedSliceContext op_context(context, node);
+inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) {
+ const int dim = op_context->input->dims->data[idx];
+ const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
+ return op_context->params->begin_mask & (1 << idx)
+ ? pos_stride ? 0 : dim - 1
+ : ClampedIndex(GetTensorData<int32_t>(op_context->begin)[idx], dim,
+ pos_stride);
+}
- std::vector<int> starts;
- std::vector<int> stops;
- std::vector<int> strides;
+inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) {
+ const int dim = op_context->input->dims->data[idx];
+ const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
+ return op_context->params->end_mask & (1 << idx)
+ ? pos_stride ? dim : -1
+ : ClampedIndex(GetTensorData<int32_t>(op_context->end)[idx], dim,
+ pos_stride);
+}
+
+// Processes the indexing tensors (begin, end and strides) to resize the
+// output tensor. This function is callable from both Prepare() and Eval() as
+// long as the caller ensures the indexing tensors are present.
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ StridedSliceContext* op_context) {
std::vector<int> output_shape_vector;
- for (int idx = op_context.dims - 1; idx >= 0; --idx) {
- int dim = op_context.input->dims->data[idx];
- int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx];
+ for (int idx = op_context->dims - 1; idx >= 0; --idx) {
+ int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
- bool pos_stride = stride > 0;
-
- int32_t begin =
- op_context.params->begin_mask & (1 << idx)
- ? pos_stride ? 0 : dim - 1
- : ClampedIndex(GetTensorData<int32_t>(op_context.begin)[idx], dim,
- pos_stride);
- int32_t end =
- op_context.params->end_mask & (1 << idx)
- ? pos_stride ? dim : -1
- : ClampedIndex(GetTensorData<int32_t>(op_context.end)[idx], dim,
- pos_stride);
+
+ int32_t begin = GetBeginValueAtIndex(op_context, idx);
+ int32_t end = GetEndValueAtIndex(op_context, idx);
// This is valid for both positive and negative strides
int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
dim_shape = dim_shape < 0 ? 0 : dim_shape;
-
- if (!(op_context.params->shrink_axis_mask & (1 << idx))) {
+ if (!(op_context->params->shrink_axis_mask & (1 << idx))) {
output_shape_vector.push_back(dim_shape);
}
-
- starts.emplace_back(begin);
- stops.emplace_back(end);
- strides.emplace_back(stride);
}
TfLiteIntArray* output_shape =
@@ -189,22 +133,73 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
output_shape->data);
+ TF_LITE_ENSURE_STATUS(
+ context->ResizeTensor(context, op_context->output, output_shape));
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ StridedSliceContext op_context(context, node);
+
+ // Ensure validity of input tensor and its dimension
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+ // Only INT32 begin/end/strides are supported
+ // TODO(soroosh) add support for INT64
+ TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
+ TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
+ "StridedSlice op only supports 1D-4D input arrays.");
+
+ // TODO(soroosh): add the following missing functionalities
+ TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
+ "ellipsis_mask is not implemented yet.");
+ TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
+ "new_axis_mask is not implemented yet.");
+
+ // Postpone allocation of output if any of the indexing tensors is not
+ // constant
+ if (!(IsConstantTensor(op_context.begin) &&
+ IsConstantTensor(op_context.end) &&
+ IsConstantTensor(op_context.strides))) {
+ SetTensorToDynamic(op_context.output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputTensor(context, &op_context);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ StridedSliceContext op_context(context, node);
+
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TfLiteTensorRealloc(op_context.output->bytes, op_context.output);
+ }
+
+ std::vector<int32_t> starts;
+ std::vector<int32_t> stops;
+ std::vector<int32_t> strides;
+
+ for (int idx = op_context.dims - 1; idx >= 0; --idx) {
+ starts.emplace_back(GetBeginValueAtIndex(&op_context, idx));
+ stops.emplace_back(GetEndValueAtIndex(&op_context, idx));
+ strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
+ }
+
for (int i = op_context.dims; i < kMaxDim; i++) {
starts.emplace_back(0);
stops.emplace_back(1);
strides.emplace_back(1);
}
- TF_LITE_ENSURE_STATUS(
- context->ResizeTensor(context, op_context.output, output_shape));
-
- size_t required_bytes;
- TF_LITE_ENSURE_OK(
- context,
- BytesRequired(context, op_context.output->type, output_shape->data,
- output_shape->size, &required_bytes));
- TfLiteTensorRealloc(required_bytes, op_context.output);
-
op_context.params->begin_mask =
ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
op_context.params->end_mask =
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index e7606eecc4..b2227a7c98 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1560,10 +1560,11 @@ def make_strided_slice_tests(zip_path):
"input_shape": [[12, 2, 2, 5]],
"begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
"end": [[8, 2, 2, 3], [12, 2, 2, 5]],
- "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]],
- "begin_mask": [None, 1, 2, 8],
- "end_mask": [None, 1, 2, 8],
- "shrink_axis_mask": [None, 1, 2, 4, 8, 11, 15, -1],
+ "strides": [None, [2, 1, 3, 1]],
+ "begin_mask": [None, 1, 8],
+ "end_mask": [None, 1, 8],
+ "shrink_axis_mask": [None, 1, 8, 11, 15, -1],
+ "constant_indices": [False, True],
},
# 2-D
{
@@ -1572,10 +1573,11 @@ def make_strided_slice_tests(zip_path):
"input_shape": [[2, 3]],
"begin": [[0, 0], [1, 0]],
"end": [[2, 3], [2, 2]],
- "strides": [None, [1, 1], [2, 2]],
+ "strides": [None, [2, 2]],
"begin_mask": [None, 1, 2],
"end_mask": [None, 1, 2],
"shrink_axis_mask": [None, 1, 2, 3, -1],
+ "constant_indices": [False, True],
},
# Negative strides
{
@@ -1588,6 +1590,7 @@ def make_strided_slice_tests(zip_path):
"begin_mask": [None, 1, 2],
"end_mask": [None, 1, 2],
"shrink_axis_mask": [None, 1, 2, 3, -1],
+ "constant_indices": [False],
},
]
@@ -1597,23 +1600,29 @@ def make_strided_slice_tests(zip_path):
dtype=parameters["dtype"],
name="input",
shape=parameters["input_shape"])
- begin = tf.placeholder(
- dtype=parameters["index_type"],
- name="begin",
- shape=[len(parameters["input_shape"])])
- end = tf.placeholder(
- dtype=parameters["index_type"],
- name="end",
- shape=[len(parameters["input_shape"])])
- strides = (
- tf.placeholder(
- dtype=parameters["index_type"],
- name="strides",
- shape=[len(parameters["input_shape"])])
- if parameters["strides"] is not None else None)
- tensors = [input_tensor, begin, end]
- if strides is not None:
- tensors.append(strides)
+ if parameters["constant_indices"]:
+ begin = parameters["begin"]
+ end = parameters["end"]
+ strides = parameters["strides"]
+ tensors = [input_tensor]
+ else:
+ begin = tf.placeholder(
+ dtype=parameters["index_type"],
+ name="begin",
+ shape=[len(parameters["input_shape"])])
+ end = tf.placeholder(
+ dtype=parameters["index_type"],
+ name="end",
+ shape=[len(parameters["input_shape"])])
+ strides = (
+ tf.placeholder(
+ dtype=parameters["index_type"],
+ name="strides",
+ shape=[len(parameters["input_shape"])])
+ if parameters["strides"] is not None else None)
+ tensors = [input_tensor, begin, end]
+ if strides is not None:
+ tensors.append(strides)
out = tf.strided_slice(
input_tensor,
begin,
@@ -1628,14 +1637,17 @@ def make_strided_slice_tests(zip_path):
input_values = create_tensor_data(parameters["dtype"],
parameters["input_shape"])
index_type = _TF_TYPE_INFO[parameters["index_type"]][0]
- begin_values = np.array(parameters["begin"]).astype(index_type)
- end_values = np.array(parameters["end"]).astype(index_type)
- stride_values = (
- np.array(parameters["strides"]).astype(index_type)
- if parameters["strides"] is not None else None)
- values = [input_values, begin_values, end_values]
- if stride_values is not None:
- values.append(stride_values)
+ values = [input_values]
+ if not parameters["constant_indices"]:
+ begin_values = np.array(parameters["begin"]).astype(index_type)
+ end_values = np.array(parameters["end"]).astype(index_type)
+ stride_values = (
+ np.array(parameters["strides"]).astype(index_type)
+ if parameters["strides"] is not None else None)
+ values.append(begin_values)
+ values.append(end_values)
+ if stride_values is not None:
+ values.append(stride_values)
return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index d73c9937ce..e8b425a592 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -47,9 +47,7 @@ tensorflow::Env* env = tensorflow::Env::Default();
// Key is a substring of the test name and value is a bug number.
// TODO(ahentz): make sure we clean this list up frequently.
std::map<string, string> kBrokenTests = {
- // Add doesn't support broadcasting.
- {R"(^\/adda.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
- {R"(^\/mula.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+ // Sub and Div don't support broadcasting.
{R"(^\/diva.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
{R"(^\/suba.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 4fb3b6ae7a..7f26884bc1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1120,7 +1120,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
stop += input_array.shape().dims(i);
}
- int dim_size = (stop - start) / op->strides[i];
+ int dim_size = ceil((stop - start) / static_cast<float>(op->strides[i]));
+ dim_size = dim_size < 0 ? 0 : dim_size;
if (op->shrink_axis_mask & mask) {
CHECK_EQ(dim_size, 1) << "Output size for an axis must compute to 1 when "
"shrinking that axis";
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index ca378af4c5..9862dbe99d 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -173,7 +173,8 @@ void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
}
auto& output_float_data =
output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
- output_float_data.resize(input_flat_size);
+ output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
+ 0.f);
if (input_tensor.float_val_size() == 1) {
for (int i = 0; i < input_flat_size; i++) {
output_float_data[i] = input_tensor.float_val(0);
@@ -203,7 +204,7 @@ void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
}
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
- output_int_data.resize(input_flat_size);
+ output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
if (input_tensor.int_val_size()) {
for (int i = 0; i < input_tensor.int_val_size(); i++) {
output_int_data[i] = input_tensor.int_val(i);
@@ -229,7 +230,7 @@ void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
}
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
- output_int_data.resize(input_flat_size);
+ output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
if (input_tensor.int_val_size()) {
for (int i = 0; i < input_tensor.int_val_size(); i++) {
output_int_data[i] = input_tensor.int_val(i);
@@ -255,7 +256,7 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
}
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
- output_int_data.resize(input_flat_size);
+ output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
if (input_tensor.int64_val_size()) {
for (int i = 0; i < input_tensor.int64_val_size(); i++) {
output_int_data[i] = input_tensor.int64_val(i);
@@ -281,7 +282,7 @@ void ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
}
auto& output_string_data =
output_array->GetMutableBuffer<ArrayDataType::kString>().data;
- output_string_data.resize(input_flat_size);
+ output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
if (input_flat_size != input_tensor.string_val_size()) {
LOG(FATAL) << "Input_content string_val doesn't have the right "
"dimensions for this string tensor.";
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
index 2932ae1c8d..ff88b4fa84 100644
--- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
@@ -171,7 +171,14 @@ def _clean_save_and_restore(graph_def, op, removed_op_names):
shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes)
op.attr['dtypes'].list.type[:] = dtypes
+ if not name_op.attr['_output_shapes'].list.shape:
+ name_op.attr['_output_shapes'].list.shape.add()
+ name_op.attr['_output_shapes'].list.shape[0].dim.add()
name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names)
+
+ if not shape_op.attr['_output_shapes'].list.shape:
+ shape_op.attr['_output_shapes'].list.shape.add()
+ shape_op.attr['_output_shapes'].list.shape[0].dim.add()
shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes)
diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/py2tf/BUILD
index 3e846aefeb..cea3738499 100644
--- a/tensorflow/contrib/py2tf/BUILD
+++ b/tensorflow/contrib/py2tf/BUILD
@@ -18,69 +18,12 @@ py_library(
name = "py2tf",
srcs = [
"__init__.py",
- "api.py",
- "config.py",
- "conversion.py",
- "naming.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/py2tf/converters",
- "//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/contrib/py2tf/impl",
"@gast_archive//:gast",
"@six_archive//:six",
],
)
-
-# Separate target that allows access to internal symbols for testing.
-py_library(
- name = "py2tf_internal",
- srcs = [
- "api.py",
- "config.py",
- "conversion.py",
- "naming.py",
- ],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow:__subpackages__"],
- deps = [
- "//tensorflow/contrib/py2tf/converters",
- "//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
- "@gast_archive//:gast",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "api_test",
- srcs = ["api_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":py2tf_internal",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
- name = "conversion_test",
- srcs = ["conversion_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":py2tf_internal",
- "//tensorflow/python:client_testlib",
- "@gast_archive//:gast",
- ],
-)
-
-py_test(
- name = "naming_test",
- srcs = ["naming_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":py2tf_internal",
- "//tensorflow/python:client_testlib",
- ],
-)
diff --git a/tensorflow/contrib/py2tf/__init__.py b/tensorflow/contrib/py2tf/__init__.py
index d187da99e0..878941b3a3 100644
--- a/tensorflow/contrib/py2tf/__init__.py
+++ b/tensorflow/contrib/py2tf/__init__.py
@@ -21,11 +21,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.api import to_code
-from tensorflow.contrib.py2tf.api import to_graph
+from tensorflow.contrib.py2tf.impl.api import convert
+from tensorflow.contrib.py2tf.impl.api import graph_ready
+from tensorflow.contrib.py2tf.impl.api import to_code
+from tensorflow.contrib.py2tf.impl.api import to_graph
from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = ['to_graph', 'to_code']
+_allowed_symbols = ['to_graph', 'to_code', 'convert', 'graph_ready']
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/py2tf/impl/BUILD b/tensorflow/contrib/py2tf/impl/BUILD
new file mode 100644
index 0000000000..22f0c25cab
--- /dev/null
+++ b/tensorflow/contrib/py2tf/impl/BUILD
@@ -0,0 +1,65 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "impl",
+ srcs = [
+ "api.py",
+ "config.py",
+ "conversion.py",
+ "naming.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/contrib/py2tf/converters",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "api_test",
+ srcs = ["api_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":impl",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "conversion_test",
+ srcs = ["conversion_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":impl",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "naming_test",
+ srcs = ["naming_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":impl",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/py2tf/api.py b/tensorflow/contrib/py2tf/impl/api.py
index ca1f4e2645..4b8cf0527a 100644
--- a/tensorflow/contrib/py2tf/api.py
+++ b/tensorflow/contrib/py2tf/impl/api.py
@@ -23,8 +23,8 @@ from functools import wraps
import gast
import six
-from tensorflow.contrib.py2tf import config
-from tensorflow.contrib.py2tf import conversion
+from tensorflow.contrib.py2tf.impl import config
+from tensorflow.contrib.py2tf.impl import conversion
from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/py2tf/api_test.py b/tensorflow/contrib/py2tf/impl/api_test.py
index 2384447708..dbd079a3ca 100644
--- a/tensorflow/contrib/py2tf/api_test.py
+++ b/tensorflow/contrib/py2tf/impl/api_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf import api
-from tensorflow.contrib.py2tf import config
+from tensorflow.contrib.py2tf.impl import api
+from tensorflow.contrib.py2tf.impl import config
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/py2tf/config.py b/tensorflow/contrib/py2tf/impl/config.py
index 8c502a7a9e..0892241983 100644
--- a/tensorflow/contrib/py2tf/config.py
+++ b/tensorflow/contrib/py2tf/impl/config.py
@@ -32,7 +32,8 @@ DEFAULT_UNCOMPILED_MODULES = set((
NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
# TODO(mdan): Also allow controlling the generated names (for testability).
+# TODO(mdan): Verify that these names are not hidden by generated code.
+# TODO(mdan): Make sure copybara renames the reference below.
COMPILED_IMPORT_STATEMENTS = (
- 'from contextlib import contextmanager',
'import tensorflow as tf',
)
diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py
index 67ca52d194..ed71ff5c06 100644
--- a/tensorflow/contrib/py2tf/conversion.py
+++ b/tensorflow/contrib/py2tf/impl/conversion.py
@@ -21,8 +21,6 @@ from __future__ import print_function
import gast
import six
-from tensorflow.contrib.py2tf import config
-from tensorflow.contrib.py2tf import naming
from tensorflow.contrib.py2tf.converters import asserts
from tensorflow.contrib.py2tf.converters import break_canonicalization
from tensorflow.contrib.py2tf.converters import builtin_functions
@@ -34,6 +32,8 @@ from tensorflow.contrib.py2tf.converters import for_canonicalization
from tensorflow.contrib.py2tf.converters import logical_expressions
from tensorflow.contrib.py2tf.converters import print_functions
from tensorflow.contrib.py2tf.converters import side_effect_guards
+from tensorflow.contrib.py2tf.impl import config
+from tensorflow.contrib.py2tf.impl import naming
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct.static_analysis import access
diff --git a/tensorflow/contrib/py2tf/conversion_test.py b/tensorflow/contrib/py2tf/impl/conversion_test.py
index 26f915f4f4..3888958f19 100644
--- a/tensorflow/contrib/py2tf/conversion_test.py
+++ b/tensorflow/contrib/py2tf/impl/conversion_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.py2tf import conversion
+from tensorflow.contrib.py2tf.impl import conversion
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/py2tf/naming.py b/tensorflow/contrib/py2tf/impl/naming.py
index 5c7e4c5f95..5c7e4c5f95 100644
--- a/tensorflow/contrib/py2tf/naming.py
+++ b/tensorflow/contrib/py2tf/impl/naming.py
diff --git a/tensorflow/contrib/py2tf/naming_test.py b/tensorflow/contrib/py2tf/impl/naming_test.py
index 5cf0a3da2c..beb4e54937 100644
--- a/tensorflow/contrib/py2tf/naming_test.py
+++ b/tensorflow/contrib/py2tf/impl/naming_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf import naming
+from tensorflow.contrib.py2tf.impl import naming
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py
index e3581cc559..b458f039df 100644
--- a/tensorflow/contrib/quantize/python/graph_matcher.py
+++ b/tensorflow/contrib/quantize/python/graph_matcher.py
@@ -18,8 +18,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
-class OpTypePattern(object):
+
+class Pattern(object):
+ """The parent class of all patterns (e.g. OpTypePattern and OneofPattern)."""
+
+ @abc.abstractmethod
+ def match(self, op, tensor):
+ """Returns the result of matching op/tensor against this pattern."""
+ raise NotImplementedError('Method "match" not implemented.')
+
+
+class OpTypePattern(Pattern):
"""A tree pattern that matches TF expressions with certain op types."""
def __init__(self, op_type, name=None, inputs=None):
@@ -34,7 +45,7 @@ class OpTypePattern(object):
similar TF op types.
name: Optional string. The name of the pattern that can be looked up in
MatchResult.
- inputs: Optional list of `OpTypePattern`s or strings that specify the
+ inputs: Optional list of `Pattern`s or strings that specify the
patterns for the inputs of a matching op. If None, this pattern accepts
any inputs of a matching op.
"""
@@ -43,22 +54,51 @@ class OpTypePattern(object):
if inputs is None:
inputs = []
self._inputs = [
- input_pattern if isinstance(input_pattern, OpTypePattern) else
- OpTypePattern(input_pattern) for input_pattern in inputs
+ input_pattern
+ if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern)
+ for input_pattern in inputs
]
@property
- def op_type(self):
- return self._op_type
-
- @property
- def inputs(self):
- return self._inputs
-
- @property
def name(self):
return self._name
+ def match(self, op, tensor):
+ if self._op_type != '*':
+ if op.type not in self._op_type.split('|'):
+ return None
+
+ match_result = MatchResult()
+ match_result.add(self, op, tensor)
+
+ if not self._inputs:
+ # If pattern.inputs is empty, skips the rest and accepts all the inputs.
+ return match_result
+
+ if len(op.inputs) != len(self._inputs):
+ return None
+
+ for input_tensor, input_pattern in zip(op.inputs, self._inputs):
+ input_match_result = input_pattern.match(input_tensor.op, input_tensor)
+ if input_match_result is None:
+ return None
+ match_result.merge_from(input_match_result)
+ return match_result
+
+
+class OneofPattern(Pattern):
+ """Matches one of the given sub-patterns."""
+
+ def __init__(self, sub_patterns):
+ self._sub_patterns = sub_patterns
+
+ def match(self, op, tensor):
+ for sub_pattern in self._sub_patterns:
+ match_result = sub_pattern.match(op, tensor)
+ if match_result is not None:
+ return match_result
+ return None
+
class MatchResult(object):
r"""Encapsulates the result of a match done by GraphMatcher.
@@ -102,16 +142,36 @@ class MatchResult(object):
return pattern_or_name
if isinstance(pattern_or_name, str):
+ if pattern_or_name not in self._name_to_pattern:
+ return None
return self._name_to_pattern[pattern_or_name]
raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.'
% type(pattern_or_name))
+ def _get_op_tensor(self, pattern_or_name):
+ pattern = self._to_pattern(pattern_or_name)
+ if pattern is None:
+ return None
+
+ if pattern not in self._pattern_to_op_tensor:
+ return None
+
+ return self._pattern_to_op_tensor[pattern]
+
def get_op(self, pattern_or_name):
- return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0]
+ op_tensor = self._get_op_tensor(pattern_or_name)
+ return op_tensor[0] if op_tensor else None
def get_tensor(self, pattern_or_name):
- return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1]
+ op_tensor = self._get_op_tensor(pattern_or_name)
+ return op_tensor[1] if op_tensor else None
+
+ def merge_from(self, other_match_result):
+ # pylint: disable=protected-access
+ self._pattern_to_op_tensor.update(other_match_result._pattern_to_op_tensor)
+ self._name_to_pattern.update(other_match_result._name_to_pattern)
+ # pylint: enable=protected-access
class GraphMatcher(object):
@@ -121,7 +181,7 @@ class GraphMatcher(object):
"""Initializes a GraphMatcher.
Args:
- pattern: The `OpTypePattern` against which `GraphMatcher` matches
+ pattern: The `Pattern` against which `GraphMatcher` matches
subgraphs.
"""
self._pattern = pattern
@@ -133,7 +193,7 @@ class GraphMatcher(object):
with key `pattern`.
Args:
- pattern: An `OpTypePattern`.
+ pattern: An `Pattern`.
op: A `tf.Operation` to match against the pattern.
tensor: the output `tf.Tensor` of `op` that is used by the matching op of
`pattern`'s parent. Can be None if `pattern` is already the root of the
@@ -142,20 +202,11 @@ class GraphMatcher(object):
Returns:
True if an TF expression rooted at `op` matches `pattern`.
"""
- if pattern.op_type != '*':
- if op.type not in pattern.op_type.split('|'):
- return False
-
- self._match_result.add(pattern, op, tensor)
-
- if not pattern.inputs:
- # If pattern.inputs is empty, skips the rest and accepts all the inputs.
- return True
-
- return len(op.inputs) == len(pattern.inputs) and all([
- self._match_pattern(input_pattern, input_tensor.op, input_tensor)
- for input_tensor, input_pattern in zip(op.inputs, pattern.inputs)
- ])
+ match_result = pattern.match(op, tensor)
+ if match_result is None:
+ return False
+ self._match_result.merge_from(match_result)
+ return True
def match_op(self, op):
"""Matches `op` against `self._pattern`.
diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py
index e1572865e4..6d58757218 100644
--- a/tensorflow/contrib/quantize/python/graph_matcher_test.py
+++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py
@@ -105,7 +105,7 @@ class GraphMatcherTest(test_util.TensorFlowTestCase):
self.assertEqual(match_result.get_op(y1_pattern), y1.op)
self.assertEqual(match_result.get_tensor(y1_pattern), y1)
- def test_oneof_pattern(self):
+ def test_oneof_type_pattern(self):
# - +
# / \ / \
# x y z
@@ -125,6 +125,44 @@ class GraphMatcherTest(test_util.TensorFlowTestCase):
for match_result in matcher.match_graph(g)
], [plus.op, minus.op])
+ def test_oneof_pattern(self):
+ reshape_pattern = graph_matcher.OpTypePattern('Reshape')
+ transpose_pattern = graph_matcher.OneofPattern([
+ graph_matcher.OpTypePattern(
+ 'Transpose',
+ name='transpose',
+ inputs=[
+ graph_matcher.OpTypePattern(
+ 'Slice', name='slice', inputs=[reshape_pattern, '*', '*']),
+ '*'
+ ]),
+ graph_matcher.OpTypePattern(
+ 'Transpose', name='transpose', inputs=[reshape_pattern, '*'])
+ ])
+
+ matcher = graph_matcher.GraphMatcher(transpose_pattern)
+
+ g = ops.Graph()
+ with g.as_default():
+ inputs = array_ops.placeholder(dtypes.float32, shape=[6])
+ reshape = array_ops.reshape(inputs, [2, 3])
+ transpose = array_ops.transpose(reshape)
+ [match_result] = list(matcher.match_graph(g))
+ self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
+ self.assertEqual(match_result.get_tensor('slice'), None)
+ self.assertEqual(match_result.get_op('transpose'), transpose.op)
+
+ g = ops.Graph()
+ with g.as_default():
+ inputs = array_ops.placeholder(dtypes.float32, shape=[6])
+ reshape = array_ops.reshape(inputs, [2, 3])
+ slicing = array_ops.slice(reshape, [0, 0], [-1, -1])
+ transpose = array_ops.transpose(slicing)
+ [match_result] = list(matcher.match_graph(g))
+ self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
+ self.assertEqual(match_result.get_tensor('slice'), slicing)
+ self.assertEqual(match_result.get_op('transpose'), transpose.op)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index ee661dfdc1..a6968d8b2a 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -202,7 +202,7 @@ def create_file_writer(logdir,
if flush_millis is None:
flush_millis = constant_op.constant(2 * 60 * 1000)
if filename_suffix is None:
- filename_suffix = constant_op.constant("")
+ filename_suffix = constant_op.constant(".v2")
return _make_summary_writer(
name,
gen_summary_ops.create_summary_file_writer,
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 6a05a2abf6..b1ef9fde37 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/contrib/tpu/profiler/version.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -47,6 +48,19 @@ string GetCurrentTimeStampAsString() {
return s;
}
+Status ValidateHostPortPair(const string& host_port) {
+ uint32 port;
+ std::vector<string> parts = str_util::Split(host_port, ':');
+ // Must be host:port, port must be a number, host must not contain a '/',
+ // host also must not be empty.
+ if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) ||
+ parts[0].find("/") != string::npos || parts[0].empty()) {
+ return errors::InvalidArgument("Could not interpret \"", host_port,
+ "\" as a host-port pair.");
+ }
+ return Status::OK();
+}
+
ProfileResponse Profile(const string& service_addr, int duration_ms,
const ProfileOptions& opts) {
ProfileRequest request;
@@ -60,11 +74,14 @@ ProfileResponse Profile(const string& service_addr, int duration_ms,
::grpc::ClientContext context;
::grpc::ChannelArguments channel_args;
// TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available.
+ // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
+ // `ValidateHostPortPair` checks for empty host string case.
channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
std::numeric_limits<int32>::max());
std::unique_ptr<TPUProfiler::Stub> stub =
TPUProfiler::NewStub(::grpc::CreateCustomChannel(
- service_addr, ::grpc::InsecureChannelCredentials(), channel_args));
+ "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
+ channel_args));
ProfileResponse response;
TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response)));
return response;
@@ -101,7 +118,14 @@ int main(int argc, char** argv) {
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) {
- std::printf("%s", usage.c_str());
+ std::cout << usage.c_str() << std::endl;
+ return 2;
+ }
+ tensorflow::Status status =
+ tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr);
+ if (!status.ok()) {
+ std::cout << status.error_message() << std::endl;
+ std::cout << usage.c_str() << std::endl;
return 2;
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
@@ -130,6 +154,8 @@ int main(int argc, char** argv) {
<< std::endl
<< "Tip: increase number of attempts with --num_tracing_attempts."
<< std::endl;
+ // Don't dump profile data if no trace is collected.
+ return 0;
}
// Use the current timestamp as the run name.
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
index 64e4e6275d..ebd6185faa 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
@@ -151,8 +151,7 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir));
// Ignore computation_graph for now.
- const bool empty_trace = response.encoded_trace().empty();
- if (!empty_trace) {
+ if (!response.encoded_trace().empty()) {
LOG(INFO) << "Converting trace events to TraceViewer JSON.";
TF_RETURN_IF_ERROR(
DumpTraceToLogDirectory(profile_run_dir, response.encoded_trace(), os));
@@ -163,11 +162,9 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir,
response.op_profile(), os));
}
- if (!empty_trace && !response.tool_data().empty()) {
- for (const auto& tool_data : response.tool_data()) {
- TF_RETURN_IF_ERROR(
- DumpToolDataToLogDirectory(profile_run_dir, tool_data, os));
- }
+ for (const auto& tool_data : response.tool_data()) {
+ TF_RETURN_IF_ERROR(
+ DumpToolDataToLogDirectory(profile_run_dir, tool_data, os));
}
return Status::OK();
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
index 2f8656a37b..29ef977bac 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
@@ -29,6 +29,8 @@ namespace tpu {
// - Op profile
// - Input pipeline analyzer
// - Overview page
+// Note: this function creates a directory even when all fields in
+// ProfileResponse are unset/empty.
Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
const ProfileResponse& response,
std::ostream* os);
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 0c2580211a..188db6e2f0 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -53,7 +53,8 @@ class TPUConfig(
num_shards: The number of TPU shards in the system.
per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host
rather than Per-Core. With Per-Host input pipeline deployment, `input_fn`
- is invoked once on each host. To be precise, with a global batch size
+ is invoked once on each host. With Per-Core input pipeline deployment, it
+ is invoked once for each core. To be precise, with a global batch size
`train_batch_size` in `TPUEstimator` constructor, the batch size for each
shard is `train_batch_size` // #hosts. With Per-Core input pipeline
deployment, the shard batch size is `train_batch_size` // #cores.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 9d715bb236..c7008533f3 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -41,6 +41,7 @@ from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util
@@ -70,7 +71,12 @@ _BATCH_SIZE_KEY = 'batch_size'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
-# TODO(b/65703635): Flip the value and remove all dead code.
+
+# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is
+# only used for per-core based deployments. For per-host based pipelines, if a
+# user returns a Dataset instance it will be automatically wrapped in a
+# tf.while_loop (This can be disabled by returning features and labels
+# explicitly).
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False
@@ -215,16 +221,11 @@ class _TPUContext(object):
def is_running_on_cpu(self):
"""Determines whether the input_fn and model_fn should be invoked on CPU."""
mode = self._assert_mode()
- return ((not self._use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or
- (mode == model_fn_lib.ModeKeys.EVAL and
- self._eval_batch_size is None))
+ return (not self._use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT
@property
def global_batch_size(self):
mode = self._assert_mode()
- if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None:
- raise RuntimeError('Internal error, EVAL on TPU is not enabled, but '
- '`global_batch_size` is called.')
return (self._train_batch_size
if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size)
@@ -232,9 +233,6 @@ class _TPUContext(object):
def batch_size_for_input_fn(self):
"""Returns the shard batch size for `input_fn`."""
mode = self._assert_mode()
- # Special case for eval.
- if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None:
- return None
if self.is_running_on_cpu():
if mode == model_fn_lib.ModeKeys.TRAIN:
return self._train_batch_size
@@ -255,9 +253,6 @@ class _TPUContext(object):
def batch_size_for_model_fn(self):
"""Returns the shard batch size for `model_fn`."""
mode = self._assert_mode()
- # Special case for eval.
- if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None:
- return None
if self.is_running_on_cpu():
if mode == model_fn_lib.ModeKeys.TRAIN:
return self._train_batch_size
@@ -415,14 +410,13 @@ class TPUEstimatorSpec(
function should not capture any Tensors in `model_fn`.
`host_call` is a tuple of a `function` and a list or dictionary of `tensors`
- to pass to that function. `host_call` currently works for train() and
- evaluate(). The function's graph is executed on the CPU on every step, so
- there is communication overhead when sending tensors from TPU to CPU. To
- reduce the overhead, try reducing the size of the tensors. The `tensors` are
- concatenated along their major (batch) dimension, and so must be >= rank 1.
- The `host_call` is useful for writing summaries with
- @{tf.contrib.summary.create_file_writer}. Note that `host_call` does not
- currently work if `use_tpu` is set to False.
+ to pass to that function and returns a list of Tensors. `host_call` currently
+ works for train() and evaluate(). The Tensors returned by the function is
+ executed on the CPU on every step, so there is communication overhead when
+ sending tensors from TPU to CPU. To reduce the overhead, try reducing the
+ size of the tensors. The `tensors` are concatenated along their major (batch)
+ dimension, and so must be >= rank 1. The `host_call` is useful for writing
+ summaries with @{tf.contrib.summary.create_file_writer}.
"""
def __new__(cls,
@@ -454,10 +448,18 @@ class TPUEstimatorSpec(
def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
+ host_calls = {}
+ if self.eval_metrics is not None:
+ host_calls['eval_metrics'] = self.eval_metrics
+ if self.host_call is not None:
+ host_calls['host_call'] = self.host_call
+ host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
eval_metric_ops = None
if self.eval_metrics is not None:
- eval_metric_ops = _OutfeedHostCall.create_cpu_hostcall(
- {'eval_metrics': self.eval_metrics})['eval_metrics']
+ eval_metric_ops = host_call_ret['eval_metrics']
+ hooks = None
+ if self.host_call is not None:
+ hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
scaffold = self.scaffold_fn() if self.scaffold_fn else None
return model_fn_lib.EstimatorSpec(
mode=self.mode,
@@ -466,7 +468,10 @@ class TPUEstimatorSpec(
train_op=self.train_op,
eval_metric_ops=eval_metric_ops,
export_outputs=self.export_outputs,
- scaffold=scaffold)
+ scaffold=scaffold,
+ training_hooks=hooks,
+ evaluation_hooks=hooks,
+ prediction_hooks=hooks)
class _OpQueueContext(object):
@@ -510,12 +515,19 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
dequeue.
"""
- def __init__(self, ctx, enqueue_ops, dequeue_ops):
+ def __init__(self,
+ ctx,
+ enqueue_ops,
+ dequeue_ops,
+ run_infeed_loop_on_coordinator=True):
self._master_job = ctx.master_job
self._enqueue_ops = enqueue_ops
self._dequeue_ops = dequeue_ops
+
+ self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator
self._initial_infeed_sleep_secs = (
ctx.config.tpu_config.initial_infeed_sleep_secs)
+
self._session_cancel_timer = None
self._feed_error = None
@@ -598,15 +610,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
logging.info('%s thread starting after sleep', self._name)
try:
- if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- for _ in queue_ctx.read_iteration_counts():
- session.run(self._enqueue_ops)
- else:
+ if self._run_infeed_loop_on_coordinator:
for count, steps in enumerate(queue_ctx.read_iteration_counts()):
for i in xrange(steps):
logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
session.run(self._enqueue_ops)
- logging.debug('Infeed thread finished, shutting down.')
+ else:
+ for _ in queue_ctx.read_iteration_counts():
+ session.run(self._enqueue_ops)
+ logging.info('Infeed thread finished, shutting down.')
except Exception as e: # pylint: disable=broad-except
self._log_error(session, e)
@@ -617,6 +629,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
for i in xrange(steps):
logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
session.run(self._dequeue_ops)
+ logging.info('Outfeed thread finished, shutting down.')
except Exception as e: # pylint: disable=broad-except
self._log_error(session, e)
@@ -644,7 +657,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
self._infeed_controller.send_next_batch_signal(iterations)
- # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
iterations)
self._outfeed_controller.send_next_batch_signal(iterations)
@@ -763,11 +775,14 @@ def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn,
per_host_sharded_inputs = []
for core_ordinal in range(num_cores_per_host):
with ops.name_scope('ordinal_%d' % (core_ordinal)):
- inputs = input_fn()
- if isinstance(inputs, tuple):
- features, labels = inputs
- else:
- features, labels = inputs, None
+ inputs = _Inputs.from_input_fn(input_fn())
+ if inputs.is_dataset:
+ raise TypeError(
+ '`input_fn` returning `Dataset` is not yet supported in '
+ 'per-Core input pipeline deployment yet. Please set '
+ 'TPUConfig.per_host_input_for_training to True or return '
+ '`features` and `labels` from `input_fn`')
+ features, labels = inputs.features_and_labels()
inputs_structure_recorder.validate_and_record_structure(
features, labels)
@@ -794,14 +809,23 @@ def generate_per_host_enqueue_ops_fn_for_host(
"""Generates infeed enqueue ops for per-host input_fn on a single host."""
captured_infeed_queue = _CapturedObject()
+ hooks = []
+
+ with ops.device(device):
+ inputs = _Inputs.from_input_fn(input_fn())
+
+ is_dataset = inputs.is_dataset
+ if is_dataset:
+ hooks.append(inputs.dataset_initializer_hook())
+
def enqueue_ops_fn():
with ops.device(device):
num_cores_per_host = ctx.num_of_cores_per_host
- inputs = input_fn()
- if isinstance(inputs, tuple):
- features, labels = inputs
- else:
- features, labels = inputs, None
+ # Convert user input to features and labels. If the user returns a
+ # dataset, it is initialized and the features and labels extracted via
+ # `dataset.iterator.get_next()`
+ features, labels = inputs.features_and_labels()
+
inputs_structure_recorder.validate_and_record_structure(features, labels)
unsharded_tensor_list = (
inputs_structure_recorder.flatten_features_and_labels(
@@ -819,7 +843,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
unsharded_tensor_list, placement_function=lambda x: device))
return per_host_enqueue_ops
- return enqueue_ops_fn, captured_infeed_queue
+ return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
class _InputPipeline(object):
@@ -955,7 +979,7 @@ class _InputPipeline(object):
# Single tensor case.
unflattened_label = flattened_inputs[expected_num_features]
- return unflattened_features, unflattened_label
+ return _Inputs(unflattened_features, unflattened_label)
def __init__(self, input_fn, batch_axis, ctx):
"""Constructor.
@@ -983,7 +1007,8 @@ class _InputPipeline(object):
# While tf.while_loop is called, the body function, which invokes
# `enqueue_fn` passed in, is called to construct the graph. So, input_fn
# structure is recorded.
- enqueue_ops = self._invoke_input_fn_and_record_structure()
+ enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = (
+ self._invoke_input_fn_and_record_structure())
self._validate_input_pipeline()
@@ -994,14 +1019,18 @@ class _InputPipeline(object):
return self._inputs_structure_recorder.unflatten_features_and_labels(
values)
- return (enqueue_ops, dequeue_fn)
+ return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)
def _invoke_input_fn_and_record_structure(self):
"""Deploys the input pipeline and record input structure."""
enqueue_ops = []
infeed_queues = []
+ all_hooks = []
num_hosts = self._ctx.num_hosts
tpu_host_placement_fn = self._ctx.tpu_host_placement_function
+
+ run_infeed_loop_on_coordinator = True
+
if self._sharded_per_core:
# Per-Core input pipeline deployment.
# Invoke input pipeline for each core and placed on the corresponding
@@ -1015,6 +1044,7 @@ class _InputPipeline(object):
self._ctx, self._input_fn, self._inputs_structure_recorder))
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+ run_infeed_loop_on_coordinator = False
enqueue_ops.append(
_wrap_computation_in_while_loop(
device=host_device, op_fn=enqueue_ops_fn))
@@ -1028,12 +1058,26 @@ class _InputPipeline(object):
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
- enqueue_ops_fn, captured_infeed_queue = (
+ enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
generate_per_host_enqueue_ops_fn_for_host(
self._ctx, self._input_fn, self._inputs_structure_recorder,
self._batch_axis, host_device))
-
- if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+ all_hooks.extend(hooks)
+
+ # NOTE(xiejw): We dispatch here based on the return type of the
+ # users `input_fn`.
+ #
+ # 1. If input_fn returns a Dataset instance, we initialize the
+ # iterator outside of tf.while_loop, and call the iterator.get_next
+ # inside tf.while_loop. This should be always safe.
+ #
+ # 2. If input_fn returns (features, labels), it is too late to wrap
+ # them inside tf.while_loop, as resource initialization cannot be
+ # handled in TF control flow properly. In this case, we will use
+ # python loop to enqueue the data into TPU system. This may be
+ # slow compared to the previous case.
+ if is_dataset:
+ run_infeed_loop_on_coordinator = False
enqueue_ops.append(
_wrap_computation_in_while_loop(
device=host_device, op_fn=enqueue_ops_fn))
@@ -1044,7 +1088,7 @@ class _InputPipeline(object):
# dequeue is dtypes and types. So, any one can be used. Here, grab the
# first one.
self._infeed_queue = infeed_queues[0]
- return enqueue_ops
+ return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator
def _validate_input_pipeline(self):
# Perform some sanity checks to log user friendly information. We should
@@ -1110,7 +1154,8 @@ class _ModelFnWrapper(object):
def train_step(loss):
"""Training step function for use inside a while loop."""
del loss # unused; required in function signature.
- features, labels = dequeue_fn()
+ inputs = dequeue_fn()
+ features, labels = inputs.features_and_labels()
estimator_spec = self._verify_estimator_spec(
self._call_model_fn(features, labels))
@@ -1161,7 +1206,8 @@ class _ModelFnWrapper(object):
def eval_step(total_loss):
"""Evaluation step function for use inside a while loop."""
- features, labels = dequeue_fn()
+ inputs = dequeue_fn()
+ features, labels = inputs.features_and_labels()
tpu_estimator_spec = self._call_model_fn(features, labels)
if not isinstance(tpu_estimator_spec, TPUEstimatorSpec):
@@ -1414,6 +1460,34 @@ class _OutfeedHostCall(object):
return ret
+class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
+ """Hook to run host calls when use_tpu=False."""
+
+ def __init__(self, tensors):
+ self._tensors = tensors
+
+ def begin(self):
+ # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
+ # create a separate hook to guarantee execution order, because summaries
+ # need to be initialized before the outfeed thread starts.
+ # TODO(jhseu): Make a wrapper hook instead?
+ self._init_ops = contrib_summary.summary_writer_initializer_op()
+ # Get all the writer resources from the initializer, so we know what to
+ # flush.
+ self._finalize_ops = []
+ for op in self._init_ops:
+ self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
+
+ def after_create_session(self, session, coord):
+ session.run(self._init_ops)
+
+ def before_run(self, run_context):
+ return basic_session_run_hooks.SessionRunArgs(self._tensors)
+
+ def end(self, session):
+ session.run(self._finalize_ops)
+
+
class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
"""Count examples during runtime."""
@@ -1464,30 +1538,28 @@ class TPUEstimator(estimator_lib.Estimator):
replicating inputs and models for each core, and returning to host
periodically to run hooks.
- If `use_tpu` is false, all training, evaluation, and predict are executed on
- CPU.
-
- For training, TPUEstimator transforms a global batch size in params to a
- per-shard batch size when calling the `input_fn` and `model_fn`. Users should
- specify `train_batch_size` in constructor, and then get the batch size for
- each shard in `input_fn` and `model_fn` by `params['batch_size']`. If
- `TPUConfig.per_host_input_for_training` is `True`, `input_fn` is invoked per
- host rather than per core. In this case, a global batch size is transformed a
- per-host batch size in params for `input_fn`, but `model_fn` still gets
- per-core batch size.
-
- For evaluation, if `eval_batch_size` is None, it is executed on CPU, even if
- `use_tpu` is `True`. If `eval_batch_size` is not `None`, it is executed on
- TPU, which is an experimental feature. In this case, `model_fn` should return
- `TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the
- `eval_metrics` for TPU evaluation.
-
+ TPUEstimator transforms a global batch size in params to a per-shard batch
+ size when calling the `input_fn` and `model_fn`. Users should specify
+ global batch size in constructor, and then get the batch size for each shard
+ in `input_fn` and `model_fn` by `params['batch_size']`.
+ For training, `model_fn` gets per-core batch size; `input_fn` may get
+ per-core or per-host batch size depending on
+ `per_host_input_for_training` in `TPUConfig`.
+ For evaluation, `model_fn` gets per-core batch size and `input_fn` get
+ per-host batch size.
+
+ `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
+ for TPU evaluation.
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
`tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See
`TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns
a dict from metric string name to the result of calling a metric function,
namely a `(metric_tensor, update_op)` tuple.
+ One can set `use_tpu` to `False` for testing. All training, evaluation, and
+ predict will be executed on CPU. `input_fn` and `model_fn` will receive
+ `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.
+
Current limitations:
1. TPU evaluation only works on single host.
@@ -1560,8 +1632,7 @@ class TPUEstimator(estimator_lib.Estimator):
basic python types. There are reserved keys for `TPUEstimator`,
including 'batch_size'.
use_tpu: A bool indicating whether TPU support is enabled. Currently,
- - TPU training respects this bit.
- - If true, see `eval_batch_size` for evaluate support.
+ - TPU training and evaluation respect this bit.
- Predict still happens on CPU.
train_batch_size: An int representing the global training batch size.
TPUEstimator transforms this global batch size to a per-shard batch
@@ -1569,9 +1640,7 @@ class TPUEstimator(estimator_lib.Estimator):
Cannot be `None` if `use_tpu` is `True`. Must be divisible by
`config.tpu_config.num_shards`.
eval_batch_size: An int representing the global training batch size.
- Currently, if `None`, evaluation is still executed on CPU (even when
- `use_tpu` is True). In near future, `use_tpu` will be the only option to
- switch between TPU/CPU evaluation.
+ Must be divisible by `config.tpu_config.num_shards`.
batch_axis: A python tuple of int values describing how each tensor
produced by the Estimator `input_fn` should be split across the TPU
compute shards. For example, if your input_fn produced (images, labels)
@@ -1611,10 +1680,10 @@ class TPUEstimator(estimator_lib.Estimator):
.format(train_batch_size, config.tpu_config.num_shards))
if eval_batch_size is not None:
- if config.tpu_config.num_shards > 8:
- raise NotImplementedError(
- 'TPU evaluation is only supported with one host.')
-
+ if not isinstance(eval_batch_size, int):
+ raise ValueError('`eval_batch_size` must be an int')
+ if eval_batch_size < 1:
+ raise ValueError('`eval_batch_size` must be positive')
if eval_batch_size % config.tpu_config.num_shards != 0:
raise ValueError(
'eval batch size {} must be divisible by number of shards {}'
@@ -1687,6 +1756,14 @@ class TPUEstimator(estimator_lib.Estimator):
util_lib.check_positive_integer(steps, 'Eval steps')
+ if self._config.tpu_config.num_shards > 8:
+ raise NotImplementedError(
+ 'TPU evaluation is only supported with one host.')
+
+ if self._ctx._eval_batch_size is None: # pylint: disable=protected-access
+ raise ValueError('`eval_batch_size` cannot be `None`'
+ 'if evaluate() is called on TPU.')
+
return [
evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access
num_evals=steps),
@@ -1765,7 +1842,7 @@ class TPUEstimator(estimator_lib.Estimator):
input_fn = features
input_holders = _InputPipeline(input_fn, batch_axis, ctx)
- enqueue_ops, dequeue_fn = (
+ enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
if mode == model_fn_lib.ModeKeys.TRAIN:
@@ -1775,7 +1852,12 @@ class TPUEstimator(estimator_lib.Estimator):
if host_ops is None:
host_ops = []
hooks = [
- TPUInfeedOutfeedSessionHook(ctx, enqueue_ops, host_ops),
+ TPUInfeedOutfeedSessionHook(
+ ctx,
+ enqueue_ops,
+ host_ops,
+ run_infeed_loop_on_coordinator=(
+ run_infeed_loop_on_coordinator)),
ExamplesPerSecondHook(ctx.global_batch_size),
InstallSignalHandlerHook(),
training.LoggingTensorHook(
@@ -1784,7 +1866,7 @@ class TPUEstimator(estimator_lib.Estimator):
'step': training.get_global_step()
},
every_n_secs=30)
- ]
+ ] + input_hooks
summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
with ops.control_dependencies([loss]):
update_ops = _sync_variables_ops()
@@ -1834,9 +1916,12 @@ class TPUEstimator(estimator_lib.Estimator):
else:
host_ops = host_call_ret['host_call']
hooks = [
- TPUInfeedOutfeedSessionHook(ctx, enqueue_ops,
- eval_update_ops + host_ops),
- ]
+ TPUInfeedOutfeedSessionHook(
+ ctx,
+ enqueue_ops,
+ eval_update_ops + host_ops,
+ run_infeed_loop_on_coordinator=run_infeed_loop_on_coordinator),
+ ] + input_hooks
return model_fn_lib.EstimatorSpec(
mode,
@@ -2004,3 +2089,60 @@ class _CapturingContext(control_flow_ops.ControlFlowContext):
def __exit__(self, _, __, ___): # pylint: disable=invalid-name
self._g._set_control_flow_context(self._old) # pylint: disable=protected-access
+
+
+# TODO(xiejw): Extend this to support internal signal.
+class _Inputs(object):
+ """A data structure representing the input_fn returned values.
+
+ This also supports the returned value from input_fn as `Dataset`.
+ """
+
+ def __init__(self, features=None, labels=None, dataset=None):
+ if dataset is not None and (features is not None or labels is not None):
+ raise RuntimeError('Internal Error: Either (features and labels) or '
+ 'dataset should be provided, not both. Please file '
+ 'bug')
+
+ self._features = features
+ self._labels = labels
+
+ self._dataset = dataset
+ self._iterator = None
+
+ @staticmethod
+ def from_input_fn(return_values):
+ """Returns an `_Inputs` instance according to `input_fn` return value."""
+ if isinstance(return_values, dataset_ops.Dataset):
+ dataset = return_values
+ return _Inputs(dataset=dataset)
+
+ if isinstance(return_values, tuple):
+ features, labels = return_values
+ else:
+ features, labels = return_values, None
+ return _Inputs(features, labels)
+
+ @property
+ def is_dataset(self):
+ """Returns True if the return value from input_fn is Dataset."""
+ return self._dataset is not None
+
+ def dataset_initializer_hook(self):
+ """Returns a `SessionRunHook` to initialize this dataset.
+
+ This must be called before `features_and_labels`.
+ """
+ iterator = self._dataset.make_initializable_iterator()
+ # pylint: disable=protected-access
+ hook = estimator_lib._DatasetInitializerHook(iterator)
+ self._iterator = iterator
+ return hook
+
+ def features_and_labels(self):
+ """Gets `features` and `labels`."""
+ if self.is_dataset:
+ return (_Inputs.from_input_fn(
+ self._iterator.get_next()).features_and_labels())
+
+ return (self._features, self._labels)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 063a4e9d30..a8a8c34846 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1355,6 +1355,13 @@ tf_pyclif_proto_library(
visibility = ["//visibility:public"],
)
+tf_pyclif_proto_library(
+ name = "protobuf/device_properties_pyclif",
+ proto_lib = ":protos_all_cc",
+ proto_srcfile = "protobuf/device_properties.proto",
+ visibility = ["//visibility:public"],
+)
+
# -----------------------------------------------------------------------------
# Internal targets
diff --git a/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt
index 5d21d7bab6..ac05b54eea 100644
--- a/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt
@@ -20,10 +20,7 @@ END
}
summary: "Adds a value to the current value of a variable."
description: <<END
-Any ReadVariableOp which depends directly or indirectly on this assign is
-guaranteed to see the incremented value or a subsequent newer one.
-
-Outputs the incremented value, which can be used to totally order the
-increments to this variable.
+Any ReadVariableOp with a control dependency on this op is guaranteed to
+see the incremented value or a subsequent newer one.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_AssignSubVariableOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_AssignSubVariableOp.pbtxt
index 102201c4cb..9dd28f8711 100644
--- a/tensorflow/core/api_def/base_api/api_def_AssignSubVariableOp.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_AssignSubVariableOp.pbtxt
@@ -20,10 +20,7 @@ END
}
summary: "Subtracts a value from the current value of a variable."
description: <<END
-Any ReadVariableOp which depends directly or indirectly on this assign is
-guaranteed to see the incremented value or a subsequent newer one.
-
-Outputs the incremented value, which can be used to totally order the
-increments to this variable.
+Any ReadVariableOp with a control dependency on this op is guaranteed to
+see the decremented value or a subsequent newer one.
END
}
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 2c2c7e7c58..f866183f61 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_DATASET_H_
#define TENSORFLOW_FRAMEWORK_DATASET_H_
+#include "tensorflow/core/lib/core/status.h"
+
namespace tensorflow {
namespace dataset {
// Registry for stateful ops that need to be used in dataset functions.
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index e58f5f61f3..a376534bad 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -648,8 +648,9 @@ struct BatchNarrowMatrixTransposeDispatcher {
static_assert(
(TileLongSide & (TileLongSide - 1)) == 0,
"The length of the longer side of the tile is always a power of 2.");
- bool request_satisfied = max(tile_size_i, tile_size_j) <= TileLongSide &&
- min(tile_size_i, tile_size_j) <= TileShortSide;
+ bool request_satisfied =
+ std::max(tile_size_i, tile_size_j) <= TileLongSide &&
+ std::min(tile_size_i, tile_size_j) <= TileShortSide;
if (request_satisfied) {
LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
@@ -662,7 +663,7 @@ struct BatchNarrowMatrixTransposeDispatcher {
// determine whether it is the long side or the short side that falls short
// of the request and increase that parameter accordingly.
const bool long_side_request_not_satisfied =
- max(tile_size_i, tile_size_j) > TileLongSide;
+ std::max(tile_size_i, tile_size_j) > TileLongSide;
if (long_side_request_not_satisfied) {
BatchNarrowMatrixTransposeDispatcher<
@@ -690,8 +691,9 @@ struct BatchNarrowMatrixTransposeDispatcher<
static_assert(
(TileLongSide & (TileLongSide - 1)) == 0,
"The length of the longer side of the tile is always a power of 2.");
- bool request_satisfied = max(tile_size_i, tile_size_j) <= TileLongSide &&
- min(tile_size_i, tile_size_j) <= TileShortSide;
+ bool request_satisfied =
+ std::max(tile_size_i, tile_size_j) <= TileLongSide &&
+ std::min(tile_size_i, tile_size_j) <= TileShortSide;
if (request_satisfied) {
LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
@@ -816,7 +818,7 @@ void SwapDimension1And2InTensor3WithNarrowMatrices(
int tile_long_side_len = 0;
int tile_short_side_len = 0;
float lowest_cost = std::numeric_limits<float>::max();
- int data_long_side = max(input_dims[1], input_dims[2]);
+ int data_long_side = std::max(input_dims[1], input_dims[2]);
for (auto tile_size_pair : tile_spec) {
int proposed_tile_long_side_len = tile_size_pair.first;
@@ -861,12 +863,14 @@ void SwapDimension1And2InTensor3WithNarrowMatrices(
// Truncate the shorter size requested according to the manual limit set in
// tile_spec to make sure that we do not launch configurations violating
// hardware limits.
- requested_tile_size_i = requested_tile_size_i == tile_long_side_len
- ? tile_long_side_len
- : min(requested_tile_size_i, tile_short_side_len);
- requested_tile_size_j = requested_tile_size_j == tile_long_side_len
- ? tile_long_side_len
- : min(requested_tile_size_j, tile_short_side_len);
+ requested_tile_size_i =
+ requested_tile_size_i == tile_long_side_len
+ ? tile_long_side_len
+ : std::min(requested_tile_size_i, tile_short_side_len);
+ requested_tile_size_j =
+ requested_tile_size_j == tile_long_side_len
+ ? tile_long_side_len
+ : std::min(requested_tile_size_j, tile_short_side_len);
Dimension<3> input_dims_in_tiles = {
input_dims[0],
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index a67267418d..50ea8ad01a 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -102,6 +103,8 @@ REGISTER_OP("LookupTableFindV2")
c->set_output(0, c->UnknownShape());
return Status::OK();
});
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2");
+// TODO(b/72710477): Update this.
REGISTER_OP("LookupTableInsert")
.Input("table_handle: Ref(string)")
diff --git a/tensorflow/core/profiler/README.md b/tensorflow/core/profiler/README.md
index 7997bdfa05..57d76eb4cb 100644
--- a/tensorflow/core/profiler/README.md
+++ b/tensorflow/core/profiler/README.md
@@ -257,7 +257,7 @@ bug fix. `OpLogProto` is a good plus if it is used.
#### Teams
-* Xin Pan (xpan@google.com, github: panyx0718)
+* Xin Pan
* Chris Antaki
* Yao Zhang
* Jon Shlens
diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc
index 9b6a8e1251..13e41c17f7 100644
--- a/tensorflow/core/util/example_proto_fast_parsing_test.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc
@@ -57,6 +57,7 @@ void TestCorrectness(const string& serialized) {
Example example;
Example fast_example;
EXPECT_TRUE(example.ParseFromString(serialized));
+ example.DiscardUnknownFields();
EXPECT_TRUE(TestFastParse(serialized, &fast_example));
EXPECT_EQ(example.DebugString(), fast_example.DebugString());
if (example.DebugString() != fast_example.DebugString()) {
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index fc087d9d99..08943a527c 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -173,7 +173,11 @@ type OpSpec struct {
// operation.
Attrs map[string]interface{}
- // Other possible fields: Device, ColocateWith, ControlInputs.
+ // Operations that must be executed before executing the operation
+ // being added.
+ ControlDependencies []*Operation
+
+ // Other possible fields: Device, ColocateWith.
}
// AddOperation adds an operation to g.
@@ -204,6 +208,9 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
}
}
}
+ for _, in := range args.ControlDependencies {
+ C.TF_AddControlInput(cdesc, in.c)
+ }
status := newStatus()
for name, value := range args.Attrs {
if err := setAttr(cdesc, status, name, value); err != nil {
diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go
index a9ec79463a..13de4294dc 100644
--- a/tensorflow/go/op/scope.go
+++ b/tensorflow/go/op/scope.go
@@ -33,10 +33,11 @@ import (
// A Scope object and all its derivates (e.g., obtained from Scope.SubScope)
// are not safe for concurrent use by multiple goroutines.
type Scope struct {
- graph *tf.Graph
- namemap map[string]int
- namespace string
- err *scopeErr
+ graph *tf.Graph
+ namemap map[string]int
+ namespace string
+ controlDependencies []*tf.Operation
+ err *scopeErr
}
// scopeErr is used to share errors between all derivatives of a root scope.
@@ -80,6 +81,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation {
if s.namespace != "" {
args.Name = s.namespace + "/" + args.Name
}
+ args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...)
op, err := s.graph.AddOperation(args)
if err != nil {
s.UpdateErr(args.Type, err)
@@ -103,6 +105,28 @@ func (s *Scope) SubScope(namespace string) *Scope {
}
}
+// WithControlDependencies returns a new Scope which will cause all operations
+// added to the graph to execute only after all the provided operations have
+// executed first (in addition to any other control dependencies in s).
+func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
+ // Force a copy of the control dependencies into a new underlying array on
+ // every call. We cannot alias the same underlying array as `ops`, otherwise
+ // the user could modify that array after calling s.WithControlDependencies,
+ // which would be confusing. We cannot alias the same underlying array as the
+ // original `s.controlDependencies`, since Scopes form a logical tree, and
+ // other calls to s.WithControlDependencies could stomp on each other.
+ deps := make([]*tf.Operation, 0, len(s.controlDependencies)+len(ops))
+ deps = append(deps, s.controlDependencies...)
+ deps = append(deps, ops...)
+ return &Scope{
+ graph: s.graph,
+ namemap: s.namemap,
+ namespace: s.namespace,
+ controlDependencies: deps,
+ err: s.err,
+ }
+}
+
// Err returns the error, if any, encountered during the construction
// of the Graph managed by s.
//
diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go
index 6fb5d32e50..b58a61de98 100644
--- a/tensorflow/go/op/scope_test.go
+++ b/tensorflow/go/op/scope_test.go
@@ -69,6 +69,49 @@ func TestScopeSubScopeErrors(t *testing.T) {
}
}
+func TestControlDependencies(t *testing.T) {
+ var (
+ s = NewScope()
+ zero = Const(s.SubScope("zero"), int32(0))
+ one = Const(s.SubScope("one"), int32(1))
+ variable = VarHandleOp(s, tf.Int32, tf.ScalarShape())
+ init = AssignVariableOp(s, variable, zero)
+ update = AssignAddVariableOp(s, variable, one)
+ readDeps = []*tf.Operation{update}
+ )
+ // We intend for `read` to have a control dependency on `update`.
+ s = s.WithControlDependencies(readDeps...)
+ // Ensure that Scope.WithControlDependencies makes a copy of the underlying
+ // array, rather than just holding a slice reference to the same user-supplied
+ // underlying array. If the copy is correctly performed, overwriting
+ // readDeps[0] should have no effect on control dependencies for `read`.
+ readDeps[0] = init
+ read := ReadVariableOp(s, variable, tf.Int32)
+
+ graph, err := s.Finalize()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sess, err := tf.NewSession(graph, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err = sess.Run(nil, nil, []*tf.Operation{init}); err != nil {
+ t.Fatal(err)
+ }
+ // Without the control dependency, the read operation may not see the
+ // update.
+ for i := int32(0); i < 10; i++ {
+ out, err := sess.Run(nil, []tf.Output{read}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := out[0].Value().(int32), i+1; got != want {
+ t.Errorf("Got %d, want %d", got, want)
+ }
+ }
+}
+
func TestScopeFinalize(t *testing.T) {
var (
root = NewScope()
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index e6f94396b8..6befeb846d 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import session_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
class SessionInterface(object):
@@ -1441,6 +1442,7 @@ class BaseSession(SessionInterface):
return handles
+@tf_export('Session')
class Session(BaseSession):
"""A class for running TensorFlow operations.
@@ -1581,6 +1583,7 @@ class Session(BaseSession):
tf_session.TF_Reset(target, containers, config)
+@tf_export('InteractiveSession')
class InteractiveSession(BaseSession):
"""A TensorFlow `Session` for use in interactive contexts, such as a shell.
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index f41e807b4c..b7afb8af46 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -41,8 +41,10 @@ from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("data.Dataset")
class Dataset(object):
"""Represents a potentially large set of elements.
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 53a3244ce1..e573fe0192 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple
@@ -47,6 +48,7 @@ GET_NEXT_CALL_WARNING_MESSAGE = (
"`next_element` inside the loop.")
+@tf_export("data.Iterator")
class Iterator(object):
"""Represents the state of iterating through a `Dataset`."""
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index 830dc5cec4..fa7601741b 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -23,12 +23,14 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
# TODO(b/64974358): Increase default buffer size to 256 MB.
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB
+@tf_export("data.TextLineDataset")
class TextLineDataset(Dataset):
"""A `Dataset` comprising lines from one or more text files."""
@@ -71,6 +73,7 @@ class TextLineDataset(Dataset):
return dtypes.string
+@tf_export("data.TFRecordDataset")
class TFRecordDataset(Dataset):
"""A `Dataset` comprising records from one or more TFRecord files."""
@@ -115,6 +118,7 @@ class TFRecordDataset(Dataset):
return dtypes.string
+@tf_export("data.FixedLengthRecordDataset")
class FixedLengthRecordDataset(Dataset):
"""A `Dataset` of fixed-length records from one or more binary files."""
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 9e3382d4f3..ab81d40148 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -206,29 +206,6 @@ cc_library(
],
)
-cc_library(
- name = "python_eager_op_gen_main",
- srcs = [
- "python_eager_op_gen_main.cc",
- ],
- visibility = ["//visibility:public"],
- deps = [
- ":python_eager_op_gen",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:op_gen_lib",
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
-tf_cc_binary(
- name = "python_eager_op_gen_demo",
- deps = [
- ":python_eager_op_gen_main",
- "//tensorflow/core:ops",
- ],
-)
-
py_library(
name = "custom_gradient",
srcs = ["custom_gradient.py"],
diff --git a/tensorflow/python/eager/gen_op.bzl b/tensorflow/python/eager/gen_op.bzl
deleted file mode 100644
index 8bc1d6c10a..0000000000
--- a/tensorflow/python/eager/gen_op.bzl
+++ /dev/null
@@ -1,65 +0,0 @@
-"""For eager-mode Python."""
-
-load("//tensorflow:tensorflow.bzl",
- "clean_dep",
- "tf_binary_additional_srcs",
- "tf_copts",
- "tf_cc_binary")
-
-def tfe_gen_op_wrapper_py(name,
- out=None,
- visibility=None,
- deps=[],
- generated_target_name=None,
- # ApiDefs will be loaded in the order specified in this list.
- api_def_srcs=[]):
- """Generate an eager-mode Python op wrapper for an op library."""
- # Construct a cc_binary containing the specified ops.
- tool_name = "gen_" + name + "_py_wrappers_cc"
- if not deps:
- deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
- tf_cc_binary(
- name=tool_name,
- linkopts=["-lm"],
- copts=tf_copts(),
- linkstatic=1,
- deps=([
- clean_dep("//tensorflow/python/eager:python_eager_op_gen_main")
- ] + deps),
- visibility=[clean_dep("//visibility:public")],)
-
- # Invoke the previous cc_binary to generate a python file.
- if not out:
- out = "gen_" + name + ".py"
-
- if not api_def_srcs:
- api_def_args_str = ","
- else:
- api_def_args = []
- for api_def_src in api_def_srcs:
- # Add directory of the first ApiDef source to args.
- # We are assuming all ApiDefs in a single api_def_src are in the
- # same directory.
- api_def_args.append(
- "$$(dirname $$(echo $(locations " + api_def_src +
- ") | cut -d\" \" -f1))")
- api_def_args_str = ",".join(api_def_args)
-
- native.genrule(
- name=name + "_pygenrule",
- outs=[out],
- srcs=api_def_srcs,
- tools=[tool_name] + tf_binary_additional_srcs(),
- cmd=("$(location " + tool_name + ") " + api_def_args_str + " > $@"))
-
- # Make a py_library out of the generated python file.
- if not generated_target_name:
- generated_target_name = name
- native.py_library(
- name=generated_target_name,
- srcs=[out],
- srcs_version="PY2AND3",
- visibility=visibility,
- deps=[
- clean_dep("//tensorflow/python/eager:framework_for_generated_wrappers"),
- ],)
diff --git a/tensorflow/python/eager/python_eager_op_gen_main.cc b/tensorflow/python/eager/python_eager_op_gen_main.cc
deleted file mode 100644
index 05351bd8b1..0000000000
--- a/tensorflow/python/eager/python_eager_op_gen_main.cc
+++ /dev/null
@@ -1,72 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/python/eager/python_eager_op_gen.h"
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/init_main.h"
-
-namespace tensorflow {
-namespace {
-
-void PrintAllPythonOps(const std::vector<string>& hidden_ops,
- const std::vector<string>& api_def_dirs) {
- OpList ops;
- OpRegistry::Global()->Export(false, &ops);
-
- ApiDefMap api_def_map(ops);
- if (!api_def_dirs.empty()) {
- Env* env = Env::Default();
-
- for (const auto& api_def_dir : api_def_dirs) {
- std::vector<string> api_files;
- TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
- &api_files));
- TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
- }
- api_def_map.UpdateDocs();
- }
-
- PrintEagerPythonOps(ops, api_def_map, hidden_ops, true /* require_shapes */);
-}
-
-} // namespace
-} // namespace tensorflow
-
-int main(int argc, char* argv[]) {
- tensorflow::port::InitMain(argv[0], &argc, &argv);
-
- // Usage:
- // python_eager_op_gen_main api_def_dir1,api_def_dir2,...
- if (argc == 1) {
- tensorflow::PrintAllPythonOps({}, {});
- } else if (argc == 2) {
- const std::vector<tensorflow::string> api_def_dirs =
- tensorflow::str_util::Split(argv[1], ",",
- tensorflow::str_util::SkipEmpty());
- tensorflow::PrintAllPythonOps({}, api_def_dirs);
- } else {
- return -1;
- }
- return 0;
-}
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 836998cfdc..d927f3abed 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -528,6 +528,34 @@ tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
return tape_set;
}
+// A safe copy of the current tapeset. Does not get affected by other python
+// threads changing the set of active tapes.
+class SafeTapeSet {
+ public:
+ SafeTapeSet() : tape_set_(*GetTapeSet()) {
+ for (auto* tape : tape_set_) {
+ Py_INCREF(tape);
+ }
+ }
+
+ ~SafeTapeSet() {
+ for (auto* tape : tape_set_) {
+ Py_DECREF(tape);
+ }
+ }
+
+ tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator begin() {
+ return tape_set_.begin();
+ }
+
+ tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator end() {
+ return tape_set_.end();
+ }
+
+ private:
+ tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_;
+};
+
// xcode 7 doesn't define thread_local, so for compatibility we implement our
// own. TODO(apassos) remove once we can deprecate xcode 7.
#ifndef __APPLE__
@@ -718,10 +746,7 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
if (*ThreadTapeIsStopped()) {
return;
}
- // Note: making a copy because watching a variable can trigger a change to the
- // set of tapes by allowing python's garbage collector to run.
- auto tape_set = *GetTapeSet();
- for (TFE_Py_Tape* tape : tape_set) {
+ for (TFE_Py_Tape* tape : SafeTapeSet()) {
tape->tape->WatchVariable(variable);
}
}
@@ -777,8 +802,7 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
return;
}
- auto set = *GetTapeSet();
- for (TFE_Py_Tape* tape : set) {
+ for (TFE_Py_Tape* tape : SafeTapeSet()) {
Py_INCREF(backward_function);
tape->tape->RecordOperation(
op_type_str, output_info, input_ids, backward_function,
@@ -787,10 +811,7 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
}
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
- // Note: making a copy because deleting the trace can trigger a change to the
- // set of tapes by allowing python's garbage collector to run.
- auto tape_set = *GetTapeSet();
- for (TFE_Py_Tape* tape : tape_set) {
+ for (TFE_Py_Tape* tape : SafeTapeSet()) {
tape->tape->DeleteTrace(tensor_id);
}
}
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 51075731dd..83251c79fc 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -36,12 +36,14 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
+@tf_export('estimator.export.ServingInputReceiver')
class ServingInputReceiver(collections.namedtuple(
'ServingInputReceiver',
['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
@@ -118,6 +120,7 @@ class ServingInputReceiver(collections.namedtuple(
receiver_tensors_alternatives=receiver_tensors_alternatives)
+@tf_export('estimator.export.build_parsing_serving_input_receiver_fn')
def build_parsing_serving_input_receiver_fn(feature_spec,
default_batch_size=None):
"""Build a serving_input_receiver_fn expecting fed tf.Examples.
@@ -146,6 +149,7 @@ def build_parsing_serving_input_receiver_fn(feature_spec,
return serving_input_receiver_fn
+@tf_export('estimator.export.build_raw_serving_input_receiver_fn')
def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
"""Build a serving_input_receiver_fn expecting feature Tensors.
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 863af6d41d..87b964be37 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -26,8 +26,10 @@ import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('estimator.export.ExportOutput')
class ExportOutput(object):
"""Represents an output of a model that can be served.
@@ -50,6 +52,7 @@ class ExportOutput(object):
pass
+@tf_export('estimator.export.ClassificationOutput')
class ClassificationOutput(ExportOutput):
"""Represents the output of a classification head.
@@ -118,6 +121,7 @@ class ClassificationOutput(ExportOutput):
examples, self.classes, self.scores)
+@tf_export('estimator.export.RegressionOutput')
class RegressionOutput(ExportOutput):
"""Represents the output of a regression head."""
@@ -153,6 +157,7 @@ class RegressionOutput(ExportOutput):
_SINGLE_OUTPUT_DEFAULT_NAME = 'output'
+@tf_export('estimator.export.PredictOutput')
class PredictOutput(ExportOutput):
"""Represents the output of a generic prediction head.
diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py
index c4c2e30e87..a6f4712910 100644
--- a/tensorflow/python/estimator/inputs/numpy_io.py
+++ b/tensorflow/python/estimator/inputs/numpy_io.py
@@ -24,6 +24,7 @@ import numpy as np
from six import string_types
from tensorflow.python.estimator.inputs.queues import feeding_functions
+from tensorflow.python.util.tf_export import tf_export
# Key name to pack the target into dict of `features`. See
# `_get_unique_target_key` for details.
@@ -86,6 +87,7 @@ def _validate_and_convert_features(x):
return ordered_dict_data
+@tf_export('estimator.inputs.numpy_input_fn')
def numpy_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py
index 90d6145377..bd06843021 100644
--- a/tensorflow/python/estimator/inputs/pandas_io.py
+++ b/tensorflow/python/estimator/inputs/pandas_io.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.estimator.inputs.queues import feeding_functions
+from tensorflow.python.util.tf_export import tf_export
try:
# pylint: disable=g-import-not-at-top
@@ -34,6 +35,7 @@ except ImportError:
HAS_PANDAS = False
+@tf_export('estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 7feb209cc4..5947d8f6e2 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -157,6 +157,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
def _internal_input_layer(features,
@@ -209,6 +210,7 @@ def _internal_input_layer(features,
return array_ops.concat(output_tensors, 1)
+@tf_export('feature_column.input_layer')
def input_layer(features,
feature_columns,
weight_collections=None,
@@ -329,6 +331,7 @@ class InputLayer(object):
return self._input_layer_template.weights
+@tf_export('feature_column.linear_model')
def linear_model(features,
feature_columns,
units=1,
@@ -498,6 +501,7 @@ def _transform_features(features, feature_columns):
return outputs
+@tf_export('feature_column.make_parse_example_spec')
def make_parse_example_spec(feature_columns):
"""Creates parsing spec dictionary from input feature_columns.
@@ -557,6 +561,7 @@ def make_parse_example_spec(feature_columns):
return result
+@tf_export('feature_column.embedding_column')
def embedding_column(
categorical_column, dimension, combiner='mean', initializer=None,
ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None,
@@ -807,6 +812,7 @@ def shared_embedding_columns(
return result
+@tf_export('feature_column.numeric_column')
def numeric_column(key,
shape=(1,),
default_value=None,
@@ -881,6 +887,7 @@ def numeric_column(key,
normalizer_fn=normalizer_fn)
+@tf_export('feature_column.bucketized_column')
def bucketized_column(source_column, boundaries):
"""Represents discretized dense input.
@@ -970,6 +977,7 @@ def _assert_string_or_int(dtype, prefix):
'{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
+@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
hash_bucket_size,
dtype=dtypes.string):
@@ -1026,6 +1034,7 @@ def categorical_column_with_hash_bucket(key,
return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
+@tf_export('feature_column.categorical_column_with_vocabulary_file')
def categorical_column_with_vocabulary_file(key,
vocabulary_file,
vocabulary_size=None,
@@ -1145,6 +1154,7 @@ def categorical_column_with_vocabulary_file(key,
dtype=dtype)
+@tf_export('feature_column.categorical_column_with_vocabulary_list')
def categorical_column_with_vocabulary_list(
key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
"""A `_CategoricalColumn` with in-memory vocabulary.
@@ -1255,6 +1265,7 @@ def categorical_column_with_vocabulary_list(
default_value=default_value, num_oov_buckets=num_oov_buckets)
+@tf_export('feature_column.categorical_column_with_identity')
def categorical_column_with_identity(key, num_buckets, default_value=None):
"""A `_CategoricalColumn` that returns identity values.
@@ -1322,6 +1333,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
key=key, num_buckets=num_buckets, default_value=default_value)
+@tf_export('feature_column.indicator_column')
def indicator_column(categorical_column):
"""Represents multi-hot representation of given categorical column.
@@ -1350,6 +1362,7 @@ def indicator_column(categorical_column):
return _IndicatorColumn(categorical_column)
+@tf_export('feature_column.weighted_categorical_column')
def weighted_categorical_column(
categorical_column, weight_feature_key, dtype=dtypes.float32):
"""Applies weight values to a `_CategoricalColumn`.
@@ -1424,6 +1437,7 @@ def weighted_categorical_column(
dtype=dtype)
+@tf_export('feature_column.crossed_column')
def crossed_column(keys, hash_bucket_size, hash_key=None):
"""Returns a column for performing crosses of categorical features.
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 6125755775..a9dd8d8e9d 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -39,6 +39,7 @@ py_library(
"_impl/keras/engine/__init__.py",
"_impl/keras/engine/topology.py",
"_impl/keras/engine/training.py",
+ "_impl/keras/engine/training_eager.py",
"_impl/keras/estimator.py",
"_impl/keras/initializers.py",
"_impl/keras/layers/__init__.py",
@@ -720,6 +721,19 @@ py_test(
)
py_test(
+ name = "training_eager_test",
+ size = "medium",
+ srcs = ["_impl/keras/engine/training_eager_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "topology_test",
size = "small",
srcs = ["_impl/keras/engine/topology_test.py"],
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 460c0dc5f3..098ea063f9 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -29,6 +29,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_module
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
@@ -326,7 +327,15 @@ def learning_phase():
Returns:
Learning phase (scalar integer tensor or Python integer).
+
+ Raises:
+ ValueError: If called when Eager execution is enabled.
"""
+ if context.in_eager_mode():
+ if 'eager' not in _GRAPH_LEARNING_PHASES:
+ raise ValueError('No learning phase set in Eager mode.')
+ return _GRAPH_LEARNING_PHASES['eager']
+
graph = ops.get_default_graph()
if graph not in _GRAPH_LEARNING_PHASES:
phase = array_ops.placeholder_with_default(
@@ -347,7 +356,10 @@ def set_learning_phase(value):
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
if value not in {0, 1}:
raise ValueError('Expected learning phase to be ' '0 or 1.')
- _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
+ if context.in_eager_mode():
+ _GRAPH_LEARNING_PHASES['eager'] = value
+ else:
+ _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
def get_session():
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py
index 64aa868f38..8354a2b8fd 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology.py
@@ -708,8 +708,10 @@ class Network(tf_network.GraphNetwork, Layer):
self.input_names.append(layer.name)
if layer.is_placeholder:
self._feed_input_names.append(layer.name)
- self._feed_inputs.append(layer.input)
self._feed_input_shapes.append(K.int_shape(self.inputs[i]))
+ # layer.input gives an error in eager mode
+ if context.in_graph_mode():
+ self._feed_inputs.append(layer.input)
for layer in self._output_layers:
self.output_names.append(layer.name)
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 699ae2edf0..43d95b1f19 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -22,17 +22,21 @@ import copy
import numpy as np
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import callbacks as cbks
from tensorflow.python.keras._impl.keras import losses
from tensorflow.python.keras._impl.keras import metrics as metrics_module
from tensorflow.python.keras._impl.keras import optimizers
+from tensorflow.python.keras._impl.keras.engine import training_eager
from tensorflow.python.keras._impl.keras.engine.topology import Network
from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer
from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import optimizer as tf_optimizer_module
try:
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
@@ -82,21 +86,24 @@ def _standardize_input_data(data,
if data[x].__class__.__name__ == 'DataFrame' else data[x]
for x in names
]
- data = [np.expand_dims(x, 1) if x.ndim == 1 else x for x in data]
except KeyError as e:
raise ValueError('No data provided for "' + e.args[0] + '". Need data '
'for each key in: ' + str(names))
elif isinstance(data, list):
- data = [
- x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
- ]
- data = [
- np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x
- for x in data
- ]
+ if isinstance(data[0], list):
+ data = [np.asarray(d) for d in data]
+ elif len(names) == 1 and isinstance(data[0], (float, int)):
+ data = [np.asarray(data)]
+ else:
+ data = [
+ x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
+ ]
else:
data = data.values if data.__class__.__name__ == 'DataFrame' else data
- data = [np.expand_dims(data, 1)] if data.ndim == 1 else [data]
+ data = [data]
+ data = [
+ np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data
+ ]
if len(data) != len(names):
if data and hasattr(data[0], 'shape'):
@@ -618,9 +625,15 @@ class Model(Network):
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
"""
loss = loss or {}
+ if context.in_eager_mode() and not isinstance(
+ optimizer, tf_optimizer_module.Optimizer):
+ raise ValueError('Only TF native optimizers are supported in Eager mode.')
+
self.optimizer = optimizers.get(optimizer)
self.loss = loss
self.loss_weights = loss_weights
+ if context.in_eager_mode() and sample_weight_mode is not None:
+ raise ValueError('sample_weight_mode is not supported in Eager mode.')
self.sample_weight_mode = sample_weight_mode
# Prepare loss functions.
@@ -651,6 +664,7 @@ class Model(Network):
loss_function = losses.get(loss)
loss_functions = [loss_function for _ in range(len(self.outputs))]
self.loss_functions = loss_functions
+
weighted_losses = [_weighted_masked_objective(fn) for fn in loss_functions]
skip_target_indices = []
skip_target_weighing_indices = []
@@ -664,11 +678,12 @@ class Model(Network):
skip_target_weighing_indices.append(i)
# Prepare output masks.
- masks = self.compute_mask(self.inputs, mask=None)
- if masks is None:
- masks = [None for _ in self.outputs]
- if not isinstance(masks, list):
- masks = [masks]
+ if context.in_graph_mode():
+ masks = self.compute_mask(self.inputs, mask=None)
+ if masks is None:
+ masks = [None for _ in self.outputs]
+ if not isinstance(masks, list):
+ masks = [masks]
# Prepare loss weights.
if loss_weights is None:
@@ -694,6 +709,32 @@ class Model(Network):
else:
raise TypeError('Could not interpret loss_weights argument: ' +
str(loss_weights) + ' - expected a list of dicts.')
+ self.loss_weights_list = loss_weights_list
+
+ # initialization for Eager mode execution
+ if context.in_eager_mode():
+ if target_tensors is not None:
+ raise ValueError('target_tensors are not currently supported in Eager'
+ 'mode.')
+ self.total_loss = None
+ self.metrics = metrics
+ self.weighted_metrics = weighted_metrics
+ self.metrics_tensors = []
+ self.metrics_names = ['loss']
+ for i in range(len(self.outputs)):
+ if len(self.outputs) > 1:
+ self.metrics_names.append(self.output_names[i] + '_loss')
+ self.nested_metrics = _collect_metrics(metrics, self.output_names)
+ self._feed_sample_weight_modes = []
+ for i in range(len(self.outputs)):
+ self._feed_sample_weight_modes.append(None)
+ self.sample_weights = []
+ self.targets = []
+ self._collected_trainable_weights = self.trainable_weights
+ for i in range(len(self.outputs)):
+ self._feed_output_names.append(self.output_names[i])
+
+ return
# Prepare targets of model.
self.targets = []
@@ -720,6 +761,7 @@ class Model(Network):
else:
raise TypeError('Expected `target_tensors` to be '
'a list or dict, but got:', target_tensors)
+
for i in range(len(self.outputs)):
if i in skip_target_indices:
self.targets.append(None)
@@ -769,7 +811,7 @@ class Model(Network):
weight = K.placeholder(ndim=2, name=name + '_sample_weights')
sample_weight_modes.append('temporal')
else:
- weight = K.placeholder(ndim=1, name=name + '_sample_weights')
+ weight = K.placeholder(ndim=1, name=name + 'sample_weights')
sample_weight_modes.append(None)
sample_weights.append(weight)
elif isinstance(sample_weight_mode, list):
@@ -929,7 +971,7 @@ class Model(Network):
self._feed_sample_weights = []
for i in range(len(self.sample_weights)):
if i not in skip_target_weighing_indices:
- self._feed_sample_weights.append(sample_weights[i])
+ self._feed_sample_weights.append(self.sample_weights[i])
# Functions for train, test and predict will
# be compiled lazily when required.
@@ -978,6 +1020,7 @@ class Model(Network):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights, loss=self.total_loss)
+
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(
@@ -1156,6 +1199,7 @@ class Model(Network):
callback_model = self
callbacks.set_model(callback_model)
+
callbacks.set_params({
'batch_size': batch_size,
'epochs': epochs,
@@ -1216,6 +1260,7 @@ class Model(Network):
np.random.shuffle(index_array)
batches = _make_batches(num_train_samples, batch_size)
+
for batch_index, (batch_start, batch_end) in enumerate(batches):
batch_ids = index_array[batch_start:batch_end]
try:
@@ -1410,6 +1455,7 @@ class Model(Network):
ins_batch[i] = ins_batch[i].toarray()
batch_outs = f(ins_batch)
+
if isinstance(batch_outs, list):
if batch_index == 0:
for batch_out in enumerate(batch_outs):
@@ -1420,7 +1466,6 @@ class Model(Network):
if batch_index == 0:
outs.append(0.)
outs[0] += batch_outs * len(batch_ids)
-
if verbose == 1:
progbar.update(batch_end)
for i in range(len(outs)):
@@ -1636,6 +1681,7 @@ class Model(Network):
batch_size=batch_size)
# Prepare validation data.
do_validation = False
+ val_ins = []
if validation_data:
do_validation = True
if len(validation_data) == 2:
@@ -1686,39 +1732,65 @@ class Model(Network):
ins = x + y + sample_weights + [1.]
else:
ins = x + y + sample_weights
- self._make_train_function()
- f = self.train_function
# Prepare display labels.
out_labels = self._get_deduped_metrics_names()
- if do_validation:
- self._make_test_function()
- val_f = self.test_function
- callback_metrics = copy.copy(out_labels) + [
- 'val_' + n for n in out_labels
- ]
+ if context.in_eager_mode():
+ if do_validation:
+ callback_metrics = copy.copy(out_labels) + [
+ 'val_' + n for n in out_labels
+ ]
+ else:
+ callback_metrics = copy.copy(out_labels)
+
+ return training_eager.fit_loop(
+ self,
+ ins,
+ out_labels=out_labels,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ val_ins=val_ins,
+ shuffle=shuffle,
+ callback_metrics=callback_metrics,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps)
else:
- callback_metrics = copy.copy(out_labels)
- val_f = None
- val_ins = []
-
- # Delegate logic to `_fit_loop`.
- return self._fit_loop(
- f,
- ins,
- out_labels=out_labels,
- batch_size=batch_size,
- epochs=epochs,
- verbose=verbose,
- callbacks=callbacks,
- val_f=val_f,
- val_ins=val_ins,
- shuffle=shuffle,
- callback_metrics=callback_metrics,
- initial_epoch=initial_epoch,
- steps_per_epoch=steps_per_epoch,
- validation_steps=validation_steps)
+ self._make_train_function()
+ f = self.train_function
+
+ if do_validation:
+ if context.in_graph_mode():
+ self._make_test_function()
+ val_f = self.test_function
+ else:
+ val_f = None
+ callback_metrics = copy.copy(out_labels) + [
+ 'val_' + n for n in out_labels
+ ]
+ else:
+ val_f = None
+ callback_metrics = copy.copy(out_labels)
+
+ # Delegate logic to `_fit_loop`.
+ return self._fit_loop(
+ f,
+ ins,
+ out_labels=out_labels,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ val_f=val_f,
+ val_ins=val_ins,
+ shuffle=shuffle,
+ callback_metrics=callback_metrics,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps)
def evaluate(self,
x=None,
@@ -1794,10 +1866,15 @@ class Model(Network):
ins = x + y + sample_weights + [0.]
else:
ins = x + y + sample_weights
- self._make_test_function()
- f = self.test_function
- return self._test_loop(
- f, ins, batch_size=batch_size, verbose=verbose, steps=steps)
+
+ if context.in_eager_mode():
+ return training_eager.test_loop(
+ self, ins, batch_size=batch_size, verbose=verbose, steps=steps)
+ else:
+ self._make_test_function()
+ f = self.test_function
+ return self._test_loop(
+ f, ins, batch_size=batch_size, verbose=verbose, steps=steps)
def predict(self, x, batch_size=None, verbose=0, steps=None):
"""Generates output predictions for the input samples.
@@ -1849,10 +1926,16 @@ class Model(Network):
ins = x + [0.]
else:
ins = x
- self._make_predict_function()
- f = self.predict_function
- return self._predict_loop(
- f, ins, batch_size=batch_size, verbose=verbose, steps=steps)
+
+ if context.in_eager_mode():
+ return training_eager.predict_loop(
+ self, ins, batch_size=batch_size, verbose=verbose, steps=steps)
+ else:
+ self._make_predict_function()
+ f = self.predict_function
+
+ return self._predict_loop(
+ f, ins, batch_size=batch_size, verbose=verbose, steps=steps)
def train_on_batch(self, x, y, sample_weight=None, class_weight=None):
"""Runs a single gradient update on a single batch of data.
@@ -1888,6 +1971,7 @@ class Model(Network):
or list of scalars (if the model has multiple outputs
and/or metrics). The attribute `model.metrics_names` will give you
the display labels for the scalar outputs.
+
"""
x, y, sample_weights = self._standardize_user_data(
x,
@@ -1899,11 +1983,16 @@ class Model(Network):
ins = x + y + sample_weights + [1.]
else:
ins = x + y + sample_weights
- self._make_train_function()
- outputs = self.train_function(ins)
- if len(outputs) == 1:
- return outputs[0]
- return outputs
+
+ if context.in_eager_mode():
+ return training_eager.train_on_batch(self, ins)
+
+ if context.in_graph_mode():
+ self._make_train_function()
+ outputs = self.train_function(ins)
+ if len(outputs) == 1:
+ return outputs[0]
+ return outputs
def test_on_batch(self, x, y, sample_weight=None):
"""Test the model on a single batch of samples.
@@ -1942,11 +2031,16 @@ class Model(Network):
ins = x + y + sample_weights + [0.]
else:
ins = x + y + sample_weights
- self._make_test_function()
- outputs = self.test_function(ins)
- if len(outputs) == 1:
- return outputs[0]
- return outputs
+
+ if context.in_eager_mode():
+ return training_eager.test_on_batch(self, ins)
+
+ if context.in_graph_mode():
+ self._make_test_function()
+ outputs = self.test_function(ins)
+ if len(outputs) == 1:
+ return outputs[0]
+ return outputs
def predict_on_batch(self, x):
"""Returns predictions for a single batch of samples.
@@ -1956,6 +2050,7 @@ class Model(Network):
Returns:
Numpy array(s) of predictions.
+
"""
x = _standardize_input_data(x, self._feed_input_names,
self._feed_input_shapes)
@@ -1963,11 +2058,25 @@ class Model(Network):
ins = x + [0.]
else:
ins = x
- self._make_predict_function()
- outputs = self.predict_function(ins)
- if len(outputs) == 1:
- return outputs[0]
- return outputs
+
+ if context.in_eager_mode():
+ ins_batch_converted = []
+ for ib in ins:
+ ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
+
+ eager_model_inputs = []
+ for i in range(len(self.inputs)):
+ eager_model_inputs.append(ins_batch_converted[i])
+
+ outs = self(eager_model_inputs) # pylint: disable=not-callable
+ return outs
+
+ if context.in_graph_mode():
+ self._make_predict_function()
+ outputs = self.predict_function(ins)
+ if len(outputs) == 1:
+ return outputs[0]
+ return outputs
def fit_generator(self,
generator,
@@ -2072,7 +2181,6 @@ class Model(Network):
model.fit_generator(generate_arrays_from_file('/my_file.txt'),
steps_per_epoch=10000, epochs=10)
```
-
Raises:
ValueError: In case the generator yields
data in an invalid format.
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
new file mode 100644
index 0000000000..0a115969ca
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -0,0 +1,666 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras training and evaluation routines.
+"""
+# pylint: disable=protected-access
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+from tensorflow.python.eager.backprop import GradientTape
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import callbacks as cbks
+from tensorflow.python.keras._impl.keras import losses
+from tensorflow.python.keras._impl.keras import metrics as metrics_module
+from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
+
+
+def _make_batches(size, batch_size):
+ """Returns a list of batch indices (tuples of indices).
+
+ Arguments:
+ size: Integer, total size of the data to slice into batches.
+ batch_size: Integer, batch size.
+
+ Returns:
+ A list of tuples of array indices.
+ """
+ num_batches = int(np.ceil(size / float(batch_size)))
+ return [(i * batch_size, min(size, (i + 1) * batch_size))
+ for i in range(0, num_batches)]
+
+
+def _slice_arrays(arrays, start=None, stop=None):
+ """Slice an array or list of arrays.
+
+ This takes an array-like, or a list of
+ array-likes, and outputs:
+ - arrays[start:stop] if `arrays` is an array-like
+ - [x[start:stop] for x in arrays] if `arrays` is a list
+
+ Can also work on list/array of indices: `_slice_arrays(x, indices)`
+
+ Arguments:
+ arrays: Single array or list of arrays.
+ start: can be an integer index (start index)
+ or a list/array of indices
+ stop: integer (stop index); should be None if
+ `start` was a list.
+
+ Returns:
+ A slice of the array(s).
+
+ Raises:
+ ValueError: If the value of start is a list and stop is not None.
+ """
+ if arrays is None:
+ return [None]
+ if isinstance(start, list) and stop is not None:
+ raise ValueError('The stop argument has to be None if the value of start is'
+ 'a list.')
+ elif isinstance(arrays, list):
+ if hasattr(start, '__len__'):
+ # hdf5 datasets only support list objects as indices
+ if hasattr(start, 'shape'):
+ start = start.tolist()
+ return [None if x is None else x[start] for x in arrays]
+ else:
+ return [None if x is None else x[start:stop] for x in arrays]
+ else:
+ if hasattr(start, '__len__'):
+ if hasattr(start, 'shape'):
+ start = start.tolist()
+ return arrays[start]
+ elif hasattr(start, '__getitem__'):
+ return arrays[start:stop]
+ else:
+ return [None]
+
+
+def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
+ if metric == 'accuracy' or metric == 'acc':
+ # custom handling of accuracy
+ # (because of class mode duality)
+ output_shape = internal_output_shapes
+ if output_shape[-1] == 1 or loss_func == losses.binary_crossentropy:
+ # case: binary accuracy
+ acc_fn = metrics_module.binary_accuracy
+ elif loss_func == losses.sparse_categorical_crossentropy:
+ # case: categorical accuracy with sparse targets
+ acc_fn = metrics_module.sparse_categorical_accuracy
+ else:
+ acc_fn = metrics_module.categorical_accuracy
+
+ metric_name = 'acc'
+ return metric_name, acc_fn
+ else:
+ metric_fn = metrics_module.get(metric)
+ metric_name = metric_fn.__name__
+ return metric_name, metric_fn
+
+
+def _eager_loss_fn(outputs, targets, loss_fn, output_name):
+ with K.name_scope(output_name + '_loss'):
+ loss = loss_fn(targets, outputs)
+ return loss
+
+
+def _eager_metrics_fn(model, outputs, targets):
+ """Calculates the metrics for each output of the given model.
+
+ Arguments:
+ model: The model on which metrics are being calculated.
+ outputs: The outputs of the given model.
+ targets: The predictions or targets of the given model.
+
+ Returns:
+ Returns the metric names and metric results for each output of the model.
+ """
+ metric_names = []
+ metric_results = []
+ if not isinstance(outputs, list):
+ outputs = [outputs]
+
+ if not isinstance(targets, list):
+ targets = [targets]
+
+ for i in range(len(model.outputs)):
+ output_metrics = model.nested_metrics[i]
+ for nested_output_metric in output_metrics:
+ metric_name, metric_fn = _get_metrics_info(
+ nested_output_metric, model._internal_output_shapes[i],
+ model.loss_functions[i])
+
+ if len(model.output_names) > 1:
+ metric_name = model.output_names[i] + '_' + metric_name
+ if metric_name not in model.metrics_names:
+ model.metrics_names.append(metric_name)
+
+ with K.name_scope(metric_name):
+ metric_result = metric_fn(outputs[i], targets[i])
+ metric_names.append(metric_name)
+ metric_results.append(K.mean(metric_result))
+
+ return metric_names, metric_results
+
+
+def _model_loss(model, inputs, targets):
+ """Calculates the loss for a given model.
+
+ Arguments:
+ model: The model on which metrics are being calculated.
+ inputs: The inputs of the given model. This is typically the mini batch of
+ data that is fed to the model.
+ targets: The predictions or targets of the given model.
+
+ Returns:
+ Returns the model output, total loss and loss value calculated using the
+ specified loss function. The total loss includes regularization losses and
+ applies masking and sample weighting to the loss value.
+ """
+ total_loss = 0
+ outs = model(inputs)
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ if not isinstance(targets, list):
+ targets = [targets]
+
+ loss_metrics = []
+ with K.name_scope('loss'):
+ for i, loss_fn in enumerate(model.loss_functions):
+ # compute the loss
+ output_loss = _eager_loss_fn(outs[i], targets[i], loss_fn,
+ model.output_names[i])
+ loss_metrics.append(K.mean(output_loss))
+
+ mask = outs[i]._keras_mask
+ # adapted from weighted_loss_fn
+ if mask is not None:
+ # mask should have the same shape as output_loss
+ output_loss *= mask
+ # the loss per batch should be proportional
+ # to the number of unmasked samples.
+ output_loss /= K.mean(mask)
+
+ # adapted from weighted_loss_fn
+ # apply sample weighting
+ if model.sample_weights:
+ # reduce score_array to same ndim as weight array
+ ndim = K.ndim(output_loss)
+ weight_ndim = K.ndim(model.sample_weights)
+ output_loss = K.mean(output_loss, axis=list(range(weight_ndim, ndim)))
+ output_loss *= model.sample_weights
+ output_loss /= K.mean(K.cast(K.not_equal(model.sample_weights, 0),
+ K.floatx()))
+ output_loss = K.mean(output_loss)
+
+ loss_weight = model.loss_weights_list[i]
+ if total_loss is None:
+ total_loss = loss_weight * output_loss
+ else:
+ total_loss += loss_weight * output_loss
+
+ total_loss = K.mean(total_loss)
+ # Add regularization losses
+ custom_losses = []
+ for layer in model.layers:
+ if layer.losses:
+ custom_losses += layer.losses
+
+ if custom_losses:
+ total_loss += sum(custom_losses)
+
+ return outs, total_loss, loss_metrics
+
+
+def _process_single_batch(eager_model_inputs, eager_model_outputs, model,
+ training=True):
+ """Calculate the loss and gradient for one input batch.
+
+ The model weights are updated if training is set to True.
+
+ Arguments:
+ eager_model_inputs: Input batch data.
+ eager_model_outputs: Output batch data.
+ model: Model whose loss has to be calculated.
+ training: The boolean represents if the weights of the model are updated.
+ 'fit' methods will set this to True while 'evaluate' methods will
+ set this to False.
+
+ Returns:
+ output of the model, total loss and the loss associated with each output.
+
+ Raises:
+ ValueError: If the model loss is 0 or if the trainable weights list is
+ empty when the trainable parameter is set to True.
+ """
+ K.set_learning_phase(training)
+ with GradientTape() as tape:
+ outs, loss, loss_metrics = _model_loss(model, eager_model_inputs,
+ eager_model_outputs)
+ if loss is None:
+ raise ValueError('The model cannot be run '
+ 'because it has no loss to optimize.')
+ if training:
+ if not model._collected_trainable_weights:
+ raise ValueError('The list of trainable weights is empty. Make sure that '
+ 'you are not setting model.trainable to False before '
+ 'compiling the model.')
+ grads = tape.gradient(loss, model._collected_trainable_weights)
+ model.optimizer.apply_gradients(zip(grads,
+ model._collected_trainable_weights))
+ return outs, loss, loss_metrics
+
+
+def train_on_batch(model, ins):
+ """Calculates the loss and gradient updates for one input batch.
+
+ Arguments:
+ model: Given model on which loss and gradients are calculated.
+ ins: Input and output batch numpy arrays.
+
+ Returns:
+ total loss and the loss associated with each output.
+ """
+ ins_batch_converted = []
+ for ib in ins:
+ ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
+ eager_model_inputs = []
+ eager_model_outputs = []
+ for i in range(len(model.inputs)):
+ eager_model_inputs.append(ins_batch_converted[i])
+ for i in range(len(model.inputs), len(ins_batch_converted)):
+ eager_model_outputs.append(ins_batch_converted[i])
+ outs, loss, _ = _process_single_batch(
+ eager_model_inputs, eager_model_outputs, model)
+ if not isinstance(outs, list):
+ outs = [outs]
+ _, metrics_results = _eager_metrics_fn(
+ model, outs, eager_model_outputs)
+ if not isinstance(loss, list):
+ loss = [loss]
+ return loss + metrics_results
+
+
+def test_on_batch(model, ins):
+ """Calculates the loss for one input batch.
+
+ Arguments:
+ model: Given model on which loss is calculated.
+ ins: Input and output batch numpy arrays.
+
+ Returns:
+ total loss, loss and metrics associated with each output.
+ """
+ ins_batch_converted = []
+ for ib in ins:
+ ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
+ eager_model_inputs = []
+ eager_model_outputs = []
+ for i in range(len(model.inputs)):
+ eager_model_inputs.append(ins_batch_converted[i])
+ for i in range(len(model.inputs), len(ins_batch_converted)):
+ eager_model_outputs.append(ins_batch_converted[i])
+ outs, loss, loss_metrics = _process_single_batch(
+ eager_model_inputs, eager_model_outputs, model, training=False)
+ if not isinstance(outs, list):
+ outs = [outs]
+ metric_names, metrics_results = _eager_metrics_fn(
+ model, outs, eager_model_outputs)
+ model.metrics_names.append(metric_names)
+ if not isinstance(loss, list):
+ loss = [loss]
+ return loss + loss_metrics + metrics_results
+
+
+def fit_loop(
+ model,
+ ins,
+ out_labels=None,
+ batch_size=None,
+ epochs=100,
+ verbose=1,
+ callbacks=None,
+ val_ins=None,
+ shuffle=True,
+ callback_metrics=None,
+ initial_epoch=0,
+ steps_per_epoch=None,
+ validation_steps=None):
+ """Abstract fit function for `f(ins)`.
+
+ Assume that f returns a list, labeled by out_labels.
+
+ Arguments:
+ model: Instance of the model that is being executed in Eager mode.
+ ins: List of tensors to be fed to `f`
+ out_labels: List of strings, display names of
+ the outputs of `f`
+ batch_size: Integer batch size or None if unknown.
+ epochs: Number of times to iterate over the data
+ verbose: Verbosity mode, 0, 1 or 2
+ callbacks: List of callbacks to be called during training
+ val_ins: List of tensors to be fed to `val_f`
+ shuffle: Whether to shuffle the data at the beginning of each epoch
+ callback_metrics: List of strings, the display names of the metrics
+ passed to the callbacks. They should be the
+ concatenation of list the display names of the outputs of
+ `f` and the list of display names of the outputs of `f_val`.
+ initial_epoch: Epoch at which to start training
+ (useful for resuming a previous training run)
+ steps_per_epoch: Total number of steps (batches of samples)
+ before declaring one epoch finished and starting the
+ next epoch. Ignored with the default value of `None`.
+ validation_steps: Number of steps to run validation for (only if doing
+ validation from data tensors). Ignored with default value of `None`.
+
+ Returns:
+ `History` object.
+
+ Raises:
+ ValueError: In case of invalid argument values.
+ """
+ # Required for Eager mode
+ K.set_learning_phase(True)
+
+ do_validation = False
+ if val_ins:
+ do_validation = True
+ if (verbose and ins and hasattr(ins[0], 'shape') and
+ hasattr(val_ins[0], 'shape')):
+ print('Train on %d samples, validate on %d samples' %
+ (ins[0].shape[0], val_ins[0].shape[0]))
+ if validation_steps:
+ if steps_per_epoch is None:
+ raise ValueError('Can only use `validation_steps` when doing step-wise '
+ 'training, i.e. `steps_per_epoch` must be set.')
+ do_validation = True
+
+ num_train_samples = model._check_num_samples(
+ ins, batch_size, steps_per_epoch, 'steps_per_epoch')
+
+ if num_train_samples is not None:
+ index_array = np.arange(num_train_samples)
+
+ model.history = cbks.History()
+ callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
+ if verbose:
+ if steps_per_epoch is not None:
+ count_mode = 'steps'
+ else:
+ count_mode = 'samples'
+ callbacks += [cbks.ProgbarLogger(count_mode)]
+ callbacks = cbks.CallbackList(callbacks)
+ out_labels = out_labels or []
+
+ # it's possible to callback a different model than self
+ # (used by Sequential models)
+ if hasattr(model, 'callback_model') and model.callback_model:
+ callback_model = model.callback_model
+ else:
+ callback_model = model
+
+ callbacks.set_model(callback_model)
+
+ callbacks.set_params({
+ 'batch_size': batch_size,
+ 'epochs': epochs,
+ 'steps': steps_per_epoch,
+ 'samples': num_train_samples,
+ 'verbose': verbose,
+ 'do_validation': do_validation,
+ 'metrics': callback_metrics or [],
+ })
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
+ for cbk in callbacks:
+ cbk.validation_data = val_ins
+
+ for epoch in range(initial_epoch, epochs):
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ if shuffle == 'batch':
+ index_array = model._batch_shuffle(index_array, batch_size)
+ elif shuffle:
+ np.random.shuffle(index_array)
+
+ batches = _make_batches(num_train_samples, batch_size)
+
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ try:
+ if isinstance(ins[-1], float):
+ # Do not slice the training phase flag.
+ ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+ else:
+ ins_batch = _slice_arrays(ins, batch_ids)
+ except TypeError:
+ raise TypeError('TypeError while preparing batch. '
+ 'If using HDF5 input data, '
+ 'pass shuffle="batch".')
+ batch_logs = {}
+ batch_logs['batch'] = batch_index
+ batch_logs['size'] = len(batch_ids)
+
+ callbacks.on_batch_begin(batch_index, batch_logs)
+
+ ins_batch_converted = []
+ for ib in ins_batch:
+ ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
+ eager_model_inputs = []
+ eager_model_outputs = []
+ for i in range(len(model.inputs)):
+ eager_model_inputs.append(ins_batch_converted[i])
+
+ for i in range(len(model.inputs), len(ins_batch_converted)):
+ eager_model_outputs.append(ins_batch_converted[i])
+
+ outs, loss, loss_metrics = _process_single_batch(eager_model_inputs,
+ eager_model_outputs,
+ model)
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ # Required for Eager mode
+ metrics_names, metrics_results = _eager_metrics_fn(model, outs,
+ eager_model_outputs)
+ batch_logs['loss'] = tensor_util.constant_value(K.mean(loss))
+
+ # TODO(anjalisridhar): Move this to compile to avoid duplicate code.
+ # In graph mode we set the metric names in compile. However in
+ # Eager mode we calculate the metrics for each batch in fit_loop.
+ # We could calculate the metric names and functions in compile.
+ # This would avoid setting the callback parameters separately.
+ # We need to do this for the first iteration alone
+ for m in metrics_names:
+ if m not in callback_metrics:
+ callback_metrics.append(m)
+
+ callbacks.set_params({
+ 'batch_size': batch_size,
+ 'epochs': epochs,
+ 'steps': steps_per_epoch,
+ 'samples': num_train_samples,
+ 'verbose': verbose,
+ 'do_validation': do_validation,
+ 'metrics': callback_metrics or [],
+ })
+
+ for k, v in zip(model.metrics_names,
+ [K.mean(loss)] + loss_metrics + metrics_results):
+ batch_logs[k] = tensor_util.constant_value(v)
+
+ callbacks.on_batch_end(batch_index, batch_logs)
+ if callback_model.stop_training:
+ break
+
+ if batch_index == len(batches) - 1: # Last batch.
+ if do_validation:
+ val_outs = test_loop(
+ model, val_ins, batch_size=batch_size, verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if callback_model.stop_training:
+ break
+ callbacks.on_train_end()
+ return model.history
+
+
+def test_loop(model, ins, batch_size=None, verbose=0, steps=None):
+ """Abstract method to loop over some data in batches.
+
+ Arguments:
+ model: Model instance that is being evaluated in Eager mode.
+ ins: list of tensors to be fed to `f`.
+ batch_size: integer batch size or `None`.
+ verbose: verbosity mode.
+ steps: Total number of steps (batches of samples)
+ before declaring predictions finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+ """
+ K.set_learning_phase(False)
+ num_samples = model._check_num_samples(ins, batch_size, steps, 'steps')
+ outs = []
+ if verbose == 1:
+ progbar = Progbar(target=num_samples)
+ batches = _make_batches(num_samples, batch_size)
+ index_array = np.arange(num_samples)
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ if isinstance(ins[-1], float):
+ # Do not slice the training phase flag.
+ ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+ else:
+ ins_batch = _slice_arrays(ins, batch_ids)
+
+ ins_batch_converted = []
+ for ib in ins_batch:
+ ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
+
+ eager_model_inputs = []
+ eager_model_outputs = []
+ for i in range(len(model.inputs)):
+ eager_model_inputs.append(ins_batch_converted[i])
+
+ for i in range(len(model.inputs), len(ins_batch_converted)):
+ eager_model_outputs.append(ins_batch_converted[i])
+
+ loss_outs, loss, loss_metrics = _model_loss(model, eager_model_inputs,
+ eager_model_outputs)
+ _, metrics_results = _eager_metrics_fn(model, loss_outs,
+ eager_model_outputs)
+ batch_outs = []
+ for _, v in zip(model.metrics_names,
+ [K.mean(loss)] + loss_metrics + metrics_results):
+ batch_outs.append(tensor_util.constant_value(v))
+
+ if isinstance(batch_outs, list):
+ if batch_index == 0:
+ for batch_out in enumerate(batch_outs):
+ outs.append(0.)
+ for i, batch_out in enumerate(batch_outs):
+ outs[i] += batch_out * len(batch_ids)
+ else:
+ if batch_index == 0:
+ outs.append(0.)
+ outs[0] += batch_outs * len(batch_ids)
+
+ if verbose == 1:
+ progbar.update(batch_end)
+ for i in range(len(outs)):
+ outs[i] /= num_samples
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+def predict_loop(model, ins, batch_size=32, verbose=0, steps=None):
+ """Abstract method to loop over some data in batches.
+
+ Arguments:
+ model:
+ ins: list of tensors to be fed to `f`.
+ batch_size: integer batch size.
+ verbose: verbosity mode.
+ steps: Total number of steps (batches of samples)
+ before declaring `_predict_loop` finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions
+ (if the model has multiple outputs).
+ """
+ K.set_learning_phase(False)
+ num_samples = model._check_num_samples(ins, batch_size, steps, 'steps')
+ if verbose == 1:
+ if steps is not None:
+ progbar = Progbar(target=steps)
+ else:
+ progbar = Progbar(target=num_samples)
+
+ outs = []
+ batches = _make_batches(num_samples, batch_size)
+ index_array = np.arange(num_samples)
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ if ins and isinstance(ins[-1], float):
+ # Do not slice the training phase flag.
+ ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+ else:
+ ins_batch = _slice_arrays(ins, batch_ids)
+
+ ins_batch_converted = []
+ for ib in ins_batch:
+ ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
+
+ eager_model_inputs = []
+ for i in range(len(model.inputs)):
+ eager_model_inputs.append(ins_batch_converted[i])
+
+ batch_outs = model(eager_model_inputs)
+
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if batch_index == 0:
+ # Pre-allocate the results arrays.
+ for batch_out in batch_outs:
+ dims = batch_out.shape[1:].dims
+ dims_list = [d.value for d in dims]
+ shape = (num_samples,) + tuple(dims_list)
+ outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype))
+ for i, batch_out in enumerate(batch_outs):
+ outs[i][batch_start:batch_end] = batch_out
+ if verbose == 1:
+ progbar.update(batch_end)
+ if len(outs) == 1:
+ return outs[0]
+ return outs
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
new file mode 100644
index 0000000000..81e2f7a514
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
@@ -0,0 +1,755 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for training routines."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
+from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
+
+
+class TrainingTest(test.TestCase):
+
+ def test_fit_on_arrays(self):
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ # Test fit at different verbosity
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+
+ # Test with validation data
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=2,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ # Test with validation split
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=0,
+ validation_split=0.2)
+
+ # Test with dictionary inputs
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {'dense': output_d_np,
+ 'dropout': output_e_np},
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {'dense': output_d_np,
+ 'dropout': output_e_np},
+ epochs=1,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {'dense': output_d_np,
+ 'dropout': output_e_np},
+ validation_data=({'input_a': input_a_np,
+ 'input_b': input_b_np
+ },
+ {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ }),
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.train_on_batch({
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {'dense': output_d_np,
+ 'dropout': output_e_np})
+ # Test with lists for loss, metrics
+ loss = ['mae', 'mse']
+ metrics = ['acc', 'mae']
+ model.compile(optimizer, loss, metrics=metrics)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+
+ # Test with dictionaries for loss, metrics, loss weights
+ loss = {'dense': 'mse', 'dropout': 'mae'}
+ loss_weights = {'dense': 1., 'dropout': 0.5}
+ metrics = {'dense': 'mse', 'dropout': 'mae'}
+ model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+
+ # Invalid use cases
+ with self.assertRaises(AttributeError):
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ validation_data=([input_a_np, input_b_np], 0, 0),
+ verbose=0)
+ with self.assertRaises(ValueError):
+ model.train_on_batch({'input_a': input_a_np},
+ [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ model.train_on_batch([input_a_np], [output_d_np, output_e_np])
+ with self.assertRaises(AttributeError):
+ model.train_on_batch(1, [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ model.train_on_batch(input_a_np, [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ bad_input = np.random.random((11, 3))
+ model.train_on_batch([bad_input, input_b_np],
+ [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ bad_target = np.random.random((11, 4))
+ model.train_on_batch([input_a_np, input_b_np],
+ [bad_target, output_e_np])
+
+ # Build single-input model
+ x = keras.layers.Input(shape=(3,), name='input_a')
+ y = keras.layers.Dense(4)(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
+ # This will work
+ model.fit([input_a_np], output_d_np, epochs=1)
+ with self.assertRaises(ValueError):
+ model.fit([input_a_np, input_a_np], output_d_np, epochs=1)
+
+ def test_evaluate_predict_on_arrays(self):
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ metrics = ['mae']
+ model.compile(
+ optimizer,
+ loss,
+ metrics=metrics,
+ loss_weights=loss_weights,
+ sample_weight_mode=None)
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ # Test evaluate at different verbosity
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=0)
+ self.assertEqual(len(out), 5)
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=1)
+ self.assertEqual(len(out), 5)
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=2)
+ self.assertEqual(len(out), 5)
+ out = model.test_on_batch([input_a_np, input_b_np],
+ [output_d_np, output_e_np])
+ self.assertEqual(len(out), 5)
+
+ # Test evaluate with dictionary inputs
+ model.evaluate(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {'dense': output_d_np,
+ 'dropout': output_e_np},
+ batch_size=5,
+ verbose=0)
+ model.evaluate(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {'dense': output_d_np,
+ 'dropout': output_e_np},
+ batch_size=5,
+ verbose=1)
+
+ # Test predict
+ out = model.predict([input_a_np, input_b_np], batch_size=5)
+ self.assertEqual(len(out), 2)
+ out = model.predict({'input_a': input_a_np, 'input_b': input_b_np})
+ self.assertEqual(len(out), 2)
+ out = model.predict_on_batch({
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ })
+ self.assertEqual(len(out), 2)
+
+ def test_invalid_loss_or_metrics(self):
+ num_classes = 5
+ train_samples = 1000
+ test_samples = 1000
+ input_dim = 5
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
+ model.add(keras.layers.Activation('relu'))
+ model.add(keras.layers.Dense(num_classes))
+ model.add(keras.layers.Activation('softmax'))
+ model.compile(loss='categorical_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+ np.random.seed(1337)
+
+ (x_train, y_train), (_, _) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+
+ with self.assertRaises(ValueError):
+ model.fit(x_train, np.concatenate([y_train, y_train], axis=-1))
+
+ with self.assertRaises(TypeError):
+ model.compile(loss='categorical_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001),
+ metrics=set(0))
+
+ with self.assertRaises(ValueError):
+ model.compile(loss=None,
+ optimizer='rms')
+
+
+class LossWeightingTest(test.TestCase):
+
+ def test_class_weights(self):
+ num_classes = 5
+ batch_size = 5
+ epochs = 5
+ weighted_class = 3
+ train_samples = 3000
+ test_samples = 3000
+ input_dim = 5
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
+ model.add(keras.layers.Activation('relu'))
+ model.add(keras.layers.Dense(num_classes))
+ model.add(keras.layers.Activation('softmax'))
+ model.compile(loss='categorical_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ int_y_test = y_test.copy()
+ int_y_train = y_train.copy()
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+ y_test = keras.utils.to_categorical(y_test, num_classes)
+ test_ids = np.where(int_y_test == np.array(weighted_class))[0]
+
+ class_weight = dict([(i, 1.) for i in range(num_classes)])
+ class_weight[weighted_class] = 2.
+
+ sample_weight = np.ones((y_train.shape[0]))
+ sample_weight[int_y_train == weighted_class] = 2.
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ class_weight=class_weight,
+ validation_data=(x_train, y_train, sample_weight))
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 2,
+ verbose=0,
+ class_weight=class_weight)
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 2,
+ verbose=0,
+ class_weight=class_weight,
+ validation_split=0.1)
+
+ model.train_on_batch(
+ x_train[:batch_size], y_train[:batch_size], class_weight=class_weight)
+ ref_score = model.evaluate(x_test, y_test, verbose=0)
+ score = model.evaluate(
+ x_test[test_ids, :], y_test[test_ids, :], verbose=0)
+ self.assertLess(score, ref_score)
+
+ def test_sample_weights(self):
+ num_classes = 5
+ batch_size = 5
+ epochs = 5
+ weighted_class = 3
+ train_samples = 3000
+ test_samples = 3000
+ input_dim = 5
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
+ model.add(keras.layers.Activation('relu'))
+ model.add(keras.layers.Dense(num_classes))
+ model.add(keras.layers.Activation('softmax'))
+ model.compile(loss='categorical_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ np.random.seed(43)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ int_y_test = y_test.copy()
+ int_y_train = y_train.copy()
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+ y_test = keras.utils.to_categorical(y_test, num_classes)
+ test_ids = np.where(int_y_test == np.array(weighted_class))[0]
+
+ class_weight = dict([(i, 1.) for i in range(num_classes)])
+ class_weight[weighted_class] = 2.
+
+ sample_weight = np.ones((y_train.shape[0]))
+ sample_weight[int_y_train == weighted_class] = 2.
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ sample_weight=sample_weight)
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ sample_weight=sample_weight,
+ validation_split=0.1)
+ model.train_on_batch(
+ x_train[:batch_size],
+ y_train[:batch_size],
+ sample_weight=sample_weight[:batch_size])
+ model.test_on_batch(
+ x_train[:batch_size],
+ y_train[:batch_size],
+ sample_weight=sample_weight[:batch_size])
+
+ def test_temporal_sample_weights(self):
+ num_classes = 5
+ weighted_class = 3
+ train_samples = 1000
+ test_samples = 1000
+ input_dim = 5
+ timesteps = 3
+
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(num_classes),
+ input_shape=(timesteps, input_dim)))
+ model.add(keras.layers.Activation('softmax'))
+
+ np.random.seed(1337)
+ (_, y_train), _ = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ int_y_train = y_train.copy()
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+
+ class_weight = dict([(i, 1.) for i in range(num_classes)])
+ class_weight[weighted_class] = 2.
+
+ sample_weight = np.ones((y_train.shape[0]))
+ sample_weight[int_y_train == weighted_class] = 2.
+ with self.assertRaises(ValueError):
+ model.compile(
+ loss='binary_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001),
+ sample_weight_mode='temporal')
+
+ def test_class_weight_invalid_use_case(self):
+ num_classes = 5
+ train_samples = 1000
+ test_samples = 1000
+ input_dim = 5
+ timesteps = 3
+
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(num_classes),
+ input_shape=(timesteps, input_dim)))
+ model.add(keras.layers.Activation('softmax'))
+ model.compile(
+ loss='binary_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+ class_weight = dict([(i, 1.) for i in range(num_classes)])
+
+ del class_weight[1]
+ with self.assertRaises(ValueError):
+ model.fit(x_train, y_train,
+ epochs=0, verbose=0, class_weight=class_weight)
+
+ with self.assertRaises(ValueError):
+ model.compile(
+ loss='binary_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001),
+ sample_weight_mode=[])
+
+ # Build multi-output model
+ x = keras.Input((3,))
+ y1 = keras.layers.Dense(4, name='1')(x)
+ y2 = keras.layers.Dense(4, name='2')(x)
+ model = keras.models.Model(x, [y1, y2])
+ model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
+ x_np = np.random.random((10, 3))
+ y_np = np.random.random((10, 4))
+ w_np = np.random.random((10,))
+ # This will work
+ model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': w_np})
+ # These will not
+ with self.assertRaises(ValueError):
+ model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=[w_np])
+ with self.assertRaises(TypeError):
+ model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=w_np)
+ with self.assertRaises(ValueError):
+ bad_w_np = np.random.random((11,))
+ model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
+ with self.assertRaises(ValueError):
+ bad_w_np = np.random.random((10, 2))
+ model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
+ with self.assertRaises(ValueError):
+ bad_w_np = np.random.random((10, 2, 2))
+ model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
+
+
+class TestDynamicTrainability(test.TestCase):
+
+ def test_trainable_warning(self):
+ x = np.random.random((5, 3))
+ y = np.random.random((5, 2))
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_dim=3))
+ model.trainable = False
+ model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
+ model.trainable = True
+ with self.assertRaises(ValueError):
+ model.train_on_batch(x, y)
+
+ def test_trainable_argument(self):
+ x = np.random.random((5, 3))
+ y = np.random.random((5, 2))
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_dim=3, trainable=False))
+ model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
+ out = model.predict(x)
+ with self.assertRaises(ValueError):
+ model.train_on_batch(x, y)
+ out_2 = model.predict(x)
+ self.assertAllClose(out, out_2)
+
+ # test with nesting
+ inputs = keras.layers.Input(shape=(3,))
+ output = model(inputs)
+ model = keras.models.Model(inputs, output)
+ model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
+ out = model.predict(x)
+ with self.assertRaises(ValueError):
+ model.train_on_batch(x, y)
+ out_2 = model.predict(x)
+ self.assertAllClose(out, out_2)
+
+ def test_layer_trainability_switch(self):
+ # with constructor argument, in Sequential
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, trainable=False, input_dim=1))
+ self.assertListEqual(model.trainable_weights, [])
+
+ # by setting the `trainable` argument, in Sequential
+ model = keras.models.Sequential()
+ layer = keras.layers.Dense(2, input_dim=1)
+ model.add(layer)
+ self.assertListEqual(model.trainable_weights, layer.trainable_weights)
+ layer.trainable = False
+ self.assertListEqual(model.trainable_weights, [])
+
+ # with constructor argument, in Model
+ x = keras.layers.Input(shape=(1,))
+ y = keras.layers.Dense(2, trainable=False)(x)
+ model = keras.models.Model(x, y)
+ self.assertListEqual(model.trainable_weights, [])
+
+ # by setting the `trainable` argument, in Model
+ x = keras.layers.Input(shape=(1,))
+ layer = keras.layers.Dense(2)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ self.assertListEqual(model.trainable_weights, layer.trainable_weights)
+ layer.trainable = False
+ self.assertListEqual(model.trainable_weights, [])
+
+ def test_model_trainability_switch(self):
+ # a non-trainable model has no trainable weights
+ x = keras.layers.Input(shape=(1,))
+ y = keras.layers.Dense(2)(x)
+ model = keras.models.Model(x, y)
+ model.trainable = False
+ self.assertListEqual(model.trainable_weights, [])
+
+ # same for Sequential
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_dim=1))
+ model.trainable = False
+ self.assertListEqual(model.trainable_weights, [])
+
+ def test_nested_model_trainability(self):
+
+ # a Sequential inside a Model
+ inner_model = keras.models.Sequential()
+ inner_model.add(keras.layers.Dense(2, input_dim=1))
+
+ x = keras.layers.Input(shape=(1,))
+ y = inner_model(x)
+ outer_model = keras.models.Model(x, y)
+ self.assertListEqual(outer_model.trainable_weights,
+ inner_model.trainable_weights)
+ inner_model.trainable = False
+ self.assertListEqual(outer_model.trainable_weights, [])
+ inner_model.trainable = True
+ inner_model.layers[-1].trainable = False
+ self.assertListEqual(outer_model.trainable_weights, [])
+
+ # a Sequential inside a Sequential
+ inner_model = keras.models.Sequential()
+ inner_model.add(keras.layers.Dense(2, input_dim=1))
+ outer_model = keras.models.Sequential()
+ outer_model.add(inner_model)
+ self.assertListEqual(outer_model.trainable_weights,
+ inner_model.trainable_weights)
+ inner_model.trainable = False
+ self.assertListEqual(outer_model.trainable_weights, [])
+ inner_model.trainable = True
+ inner_model.layers[-1].trainable = False
+ self.assertListEqual(outer_model.trainable_weights, [])
+
+ # a Model inside a Model
+ x = keras.layers.Input(shape=(1,))
+ y = keras.layers.Dense(2)(x)
+ inner_model = keras.models.Model(x, y)
+ x = keras.layers.Input(shape=(1,))
+ y = inner_model(x)
+ outer_model = keras.models.Model(x, y)
+ self.assertListEqual(outer_model.trainable_weights,
+ inner_model.trainable_weights)
+ inner_model.trainable = False
+ self.assertListEqual(outer_model.trainable_weights, [])
+ inner_model.trainable = True
+ inner_model.layers[-1].trainable = False
+ self.assertListEqual(outer_model.trainable_weights, [])
+
+ # a Model inside a Sequential
+ x = keras.layers.Input(shape=(1,))
+ y = keras.layers.Dense(2)(x)
+ inner_model = keras.models.Model(x, y)
+ outer_model = keras.models.Sequential()
+ outer_model.add(inner_model)
+ self.assertListEqual(outer_model.trainable_weights,
+ inner_model.trainable_weights)
+ inner_model.trainable = False
+ self.assertListEqual(outer_model.trainable_weights, [])
+ inner_model.trainable = True
+ inner_model.layers[-1].trainable = False
+ self.assertListEqual(outer_model.trainable_weights, [])
+
+
+class TestTrainingUtils(test.TestCase):
+
+ def test_check_array_lengths(self):
+ keras.engine.training._check_array_lengths(None, None, None)
+ a_np = np.random.random((4, 3, 3))
+ keras.engine.training._check_array_lengths(a_np, a_np, a_np)
+ keras.engine.training._check_array_lengths(
+ [a_np, a_np], [a_np, a_np], [a_np, a_np])
+ keras.engine.training._check_array_lengths([None], [None], [None])
+
+ b_np = np.random.random((3, 4))
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths(a_np, None, None)
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths(a_np, a_np, None)
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths([a_np], [None], None)
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths([a_np], [b_np], None)
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths([a_np], None, [b_np])
+
+ def test_slice_arrays(self):
+ input_a = np.random.random((10, 3))
+ keras.engine.training._slice_arrays(None)
+ keras.engine.training._slice_arrays(input_a, 0)
+ keras.engine.training._slice_arrays(input_a, 0, 1)
+ keras.engine.training._slice_arrays(input_a, stop=2)
+ input_a = [None, [1, 1], None, [1, 1]]
+ keras.engine.training._slice_arrays(input_a, 0)
+ keras.engine.training._slice_arrays(input_a, 0, 1)
+ keras.engine.training._slice_arrays(input_a, stop=2)
+ input_a = [None]
+ keras.engine.training._slice_arrays(input_a, 0)
+ keras.engine.training._slice_arrays(input_a, 0, 1)
+ keras.engine.training._slice_arrays(input_a, stop=2)
+ input_a = None
+ keras.engine.training._slice_arrays(input_a, 0)
+ keras.engine.training._slice_arrays(input_a, 0, 1)
+ keras.engine.training._slice_arrays(input_a, stop=2)
+
+ def test_fit_with_BatchNorm(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_dim=4))
+ model.add(keras.layers.BatchNormalization())
+ model.add(keras.layers.Activation('tanh'))
+ model.add(keras.layers.Dropout(0.2))
+
+ input_a_np = np.random.random((10, 4))
+ output_b_np = np.random.random((10, 10))
+
+ model.compile(loss='binary_crossentropy', optimizer=RMSPropOptimizer(0.001))
+ model.fit(input_a_np, output_b_np, epochs=1, batch_size=5, verbose=0)
+
+ def test_fit_with_regularization(self):
+ model = keras.models.Sequential()
+ with self.assertRaises(ValueError):
+ model.add(
+ keras.layers.Dense(4, input_dim=3,
+ kernel_regularizer=keras.regularizers.l2(0.01),
+ activity_regularizer=keras.regularizers.l1(0.01)))
+
+
+if __name__ == '__main__':
+ # Bazel sets these environment variables to very long paths.
+ # Tempfile uses them to create long paths, and in turn multiprocessing
+ # library tries to create sockets named after paths. Delete whatever bazel
+ # writes to these to avoid tests failing due to socket addresses being too
+ # long.
+ for var in ('TMPDIR', 'TMP', 'TEMP'):
+ if var in os.environ:
+ del os.environ[var]
+
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 5a033a04ad..b380238e4e 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -78,6 +78,14 @@ class TrainingTest(test.TestCase):
verbose=2)
model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+ # Test model with input data as a list of lists
+ model.fit(
+ [np.ndarray.tolist(input_a_np), np.ndarray.tolist(input_b_np)],
+ [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+
# Test with validation data
model.fit(
[input_a_np, input_b_np], [output_d_np, output_e_np],
@@ -205,6 +213,16 @@ class TrainingTest(test.TestCase):
with self.assertRaises(ValueError):
model.fit([input_a_np, input_a_np], output_d_np, epochs=1)
+ # Test model on a list of floats
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 4))
+
+ model.fit([np.ndarray.tolist(input_a_np)],
+ [np.ndarray.tolist(input_b_np)],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+
def test_evaluate_predict_on_arrays(self):
with self.test_session():
a = keras.layers.Input(shape=(3,), name='input_a')
diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py
index 6ee3fb48b2..ea2d3f2f04 100644
--- a/tensorflow/python/keras/_impl/keras/layers/core.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core.py
@@ -23,6 +23,7 @@ import types as python_types
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import activations
from tensorflow.python.keras._impl.keras import backend as K
@@ -119,7 +120,8 @@ class Dropout(tf_core_layers.Dropout, Layer):
if training is None:
training = K.learning_phase()
output = super(Dropout, self).call(inputs, training=training)
- if training is K.learning_phase():
+ # EagerTensor object has no attribute _uses_learning_phase
+ if not context.in_eager_mode() and training is K.learning_phase():
output._uses_learning_phase = True # pylint: disable=protected-access
return output
diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization.py b/tensorflow/python/keras/_impl/keras/layers/normalization.py
index 965ef70e6e..eecb14ceaa 100644
--- a/tensorflow/python/keras/_impl/keras/layers/normalization.py
+++ b/tensorflow/python/keras/_impl/keras/layers/normalization.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
@@ -108,7 +109,7 @@ class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer):
if training is None:
training = K.learning_phase()
output = super(BatchNormalization, self).call(inputs, training=training)
- if training is K.learning_phase():
+ if context.in_graph_mode() and training is K.learning_phase():
output._uses_learning_phase = True # pylint: disable=protected-access
return output
diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
index e47987aadc..a55a5e39a6 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -24,6 +24,7 @@ import copy
import six
from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras import backend as K
@@ -680,7 +681,14 @@ class TFOptimizer(Optimizer):
def __init__(self, optimizer): # pylint: disable=super-init-not-called
self.optimizer = optimizer
with K.name_scope(self.__class__.__name__):
- self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if context.in_graph_mode():
+ self.iterations = K.variable(0, dtype='int64', name='iterations')
+
+ def apply_gradients(self, grads):
+ self.optimizer.apply_gradients(grads)
+
+ def get_grads(self, loss, params):
+ return self.optimizer.compute_gradients(loss, params)
def get_updates(self, loss, params):
grads = self.optimizer.compute_gradients(loss, params)
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_op_test.py b/tensorflow/python/kernel_tests/extract_image_patches_op_test.py
index 5c7624f1f6..6ea9f1badc 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_op_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_op_test.py
@@ -84,7 +84,7 @@ class ExtractImagePatches(test.TestCase):
patches=patches)
def testKsize2x2Stride1x1Rate1x1Valid(self):
- """Test for 1x1 kernel ."""
+ """Test for 2x2 kernel with VALID padding."""
# [1, 2, 2, 1]
image = [[[[1], [2]], [[3], [4]]]]
# [1, 1, 1, 4]
@@ -98,7 +98,7 @@ class ExtractImagePatches(test.TestCase):
patches=patches)
def testKsize2x2Stride1x1Rate1x1Same(self):
- """Test for 1x1 kernel ."""
+ """Test for 2x2 kernel with SAME padding."""
# [1, 2, 2, 1]
image = [[[[1], [2]], [[3], [4]]]]
# [1, 2, 2, 4]
@@ -111,6 +111,20 @@ class ExtractImagePatches(test.TestCase):
padding="SAME",
patches=patches)
+ def testKsize2x2Stride1x1Rate2x2Valid(self):
+ """Test for 2x2 kernel with 2x2 dilation."""
+ # [1, 2, 2, 1]
+ image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
+ # [1, 2, 2, 4]
+ patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]],
+ [[4, 6, 12, 14], [5, 7, 13, 15]]]]
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2],
+ strides=[1, 1],
+ rates=[2, 2],
+ padding="VALID",
+ patches=patches)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 04b6056ace..5dea732cba 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import utils as layers_util
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
@@ -649,6 +650,7 @@ class Layer(object):
else:
scope_context_manager = vs.variable_scope(
self._scope, reuse=self._reuse, auxiliary_name_scope=False)
+ input_shapes = None
with scope_context_manager as scope:
with ops.name_scope(self._name_scope_name(scope)):
if not self.built:
@@ -698,6 +700,9 @@ class Layer(object):
else:
# Deferred mode behavior: use `compute_output_shape` to
# infer the number of outputs of the layer and their shapes.
+ if input_shapes is None:
+ input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+
output_shapes = self.compute_output_shape(input_shapes)
output_shapes = nest.flatten(output_shapes)
outputs = [
@@ -1393,7 +1398,10 @@ class _DeferredTensor(object):
def __init__(self, shape, dtype, name=None):
self.shape = tensor_shape.TensorShape(shape)
- self.dtype = dtypes.as_dtype(dtype)
+ if dtype is None:
+ self.dtype = dtypes.as_dtype(np.float32)
+ else:
+ self.dtype = dtypes.as_dtype(dtype)
self.name = name
def get_shape(self):
diff --git a/tensorflow/python/layers/network.py b/tensorflow/python/layers/network.py
index 0a5dd57621..745843975c 100644
--- a/tensorflow/python/layers/network.py
+++ b/tensorflow/python/layers/network.py
@@ -621,6 +621,11 @@ class GraphNetwork(base.Layer):
A list of loss tensors.
"""
losses = []
+ if context.in_eager_mode():
+ for layer in self.layers:
+ losses += layer.losses
+ return losses
+
# Retrieve losses for all internal layers.
for layer in self.layers:
if hasattr(layer, 'losses'):
@@ -853,7 +858,6 @@ class GraphNetwork(base.Layer):
for node in nodes:
# This is always a single layer, never a list.
layer = node.outbound_layer
-
reference_input_tensors = node.input_tensors
reference_output_tensors = node.output_tensors
@@ -901,12 +905,13 @@ class GraphNetwork(base.Layer):
else:
output_masks = [None for _ in range(len(output_tensors))]
- # Apply activity regularizer if any:
- if layer.activity_regularizer is not None:
- regularization_losses = [
- layer.activity_regularizer(x) for x in computed_tensors
- ]
- layer.add_loss(regularization_losses, computed_tensors)
+ if context.in_graph_mode():
+ if layer.activity_regularizer is not None:
+ regularization_losses = [
+ layer.activity_regularizer(x) for x in computed_tensors
+ ]
+ # Apply activity regularizer if any:
+ layer.add_loss(regularization_losses, computed_tensors)
if context.in_graph_mode():
# Update model updates and losses:
diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py
index 4e3071d851..59f5075f17 100644
--- a/tensorflow/python/lib/io/file_io.py
+++ b/tensorflow/python/lib/io/file_io.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
class FileIO(object):
@@ -235,6 +236,7 @@ class FileIO(object):
self._writable_file = None
+@tf_export("gfile.Exists")
def file_exists(filename):
"""Determines whether a path exists or not.
@@ -256,6 +258,7 @@ def file_exists(filename):
return True
+@tf_export("gfile.Remove")
def delete_file(filename):
"""Deletes the file located at 'filename'.
@@ -306,6 +309,7 @@ def write_string_to_file(filename, file_content):
f.write(file_content)
+@tf_export("gfile.Glob")
def get_matching_files(filename):
"""Returns a list of files that match the given pattern(s).
@@ -336,6 +340,7 @@ def get_matching_files(filename):
]
+@tf_export("gfile.MkDir")
def create_dir(dirname):
"""Creates a directory with the name 'dirname'.
@@ -353,6 +358,7 @@ def create_dir(dirname):
pywrap_tensorflow.CreateDir(compat.as_bytes(dirname), status)
+@tf_export("gfile.MakeDirs")
def recursive_create_dir(dirname):
"""Creates a directory and all parent/intermediate directories.
@@ -368,6 +374,7 @@ def recursive_create_dir(dirname):
pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(dirname), status)
+@tf_export("gfile.Copy")
def copy(oldpath, newpath, overwrite=False):
"""Copies data from oldpath to newpath.
@@ -385,6 +392,7 @@ def copy(oldpath, newpath, overwrite=False):
compat.as_bytes(oldpath), compat.as_bytes(newpath), overwrite, status)
+@tf_export("gfile.Rename")
def rename(oldname, newname, overwrite=False):
"""Rename or move a file / directory.
@@ -426,6 +434,7 @@ def atomic_write_string_to_file(filename, contents, overwrite=True):
raise
+@tf_export("gfile.DeleteRecursively")
def delete_recursively(dirname):
"""Deletes everything under dirname recursively.
@@ -439,6 +448,7 @@ def delete_recursively(dirname):
pywrap_tensorflow.DeleteRecursively(compat.as_bytes(dirname), status)
+@tf_export("gfile.IsDirectory")
def is_directory(dirname):
"""Returns whether the path is a directory or not.
@@ -452,6 +462,7 @@ def is_directory(dirname):
return pywrap_tensorflow.IsDirectory(compat.as_bytes(dirname), status)
+@tf_export("gfile.ListDirectory")
def list_directory(dirname):
"""Returns a list of entries contained within a directory.
@@ -479,6 +490,7 @@ def list_directory(dirname):
]
+@tf_export("gfile.Walk")
def walk(top, in_order=True):
"""Recursive directory tree generator for directories.
@@ -522,6 +534,7 @@ def walk(top, in_order=True):
yield here
+@tf_export("gfile.Stat")
def stat(filename):
"""Returns file statistics for a given path.
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index df19010068..48ea107a14 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -22,8 +22,10 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
"""The type of compression for the record."""
NONE = 0
@@ -33,6 +35,7 @@ class TFRecordCompressionType(object):
# NOTE(vrv): This will eventually be converted into a proto. to match
# the interface used by the C++ RecordWriter.
+@tf_export("python_io.TFRecordOptions")
class TFRecordOptions(object):
"""Options used for manipulating TFRecord files."""
compression_type_map = {
@@ -51,6 +54,7 @@ class TFRecordOptions(object):
return cls.compression_type_map[options.compression_type]
+@tf_export("python_io.tf_record_iterator")
def tf_record_iterator(path, options=None):
"""An iterator that read the records from a TFRecords file.
@@ -81,6 +85,7 @@ def tf_record_iterator(path, options=None):
reader.Close()
+@tf_export("python_io.TFRecordWriter")
class TFRecordWriter(object):
"""A class to write records to a TFRecords file.
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 72508eb435..73563486e1 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -28,8 +28,10 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util
from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("losses.Reduction")
class Reduction(object):
"""Types of loss reduction.
@@ -152,6 +154,7 @@ def _num_elements(losses):
return array_ops.size(losses, name=scope, out_type=losses.dtype)
+@tf_export("losses.compute_weighted_loss")
def compute_weighted_loss(
losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@@ -211,6 +214,7 @@ def compute_weighted_loss(
return loss
+@tf_export("losses.absolute_difference")
def absolute_difference(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -258,6 +262,7 @@ def absolute_difference(
losses, weights, scope, loss_collection, reduction=reduction)
+@tf_export("losses.cosine_distance")
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def cosine_distance(
labels, predictions, axis=None, weights=1.0, scope=None,
@@ -311,6 +316,7 @@ def cosine_distance(
losses, weights, scope, loss_collection, reduction=reduction)
+@tf_export("losses.hinge_loss")
def hinge_loss(labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@@ -352,6 +358,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
losses, weights, scope, loss_collection, reduction=reduction)
+@tf_export("losses.huber_loss")
def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@@ -420,6 +427,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
losses, weights, scope, loss_collection, reduction=reduction)
+@tf_export("losses.log_loss")
def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@@ -471,6 +479,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
# TODO(b/37208492): Add reduction arg.
+@tf_export("losses.mean_pairwise_squared_error")
def mean_pairwise_squared_error(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES):
@@ -557,6 +566,7 @@ def mean_pairwise_squared_error(
return mean_loss
+@tf_export("losses.mean_squared_error")
def mean_squared_error(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -604,6 +614,7 @@ def mean_squared_error(
losses, weights, scope, loss_collection, reduction=reduction)
+@tf_export("losses.sigmoid_cross_entropy")
def sigmoid_cross_entropy(
multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -662,6 +673,7 @@ def sigmoid_cross_entropy(
losses, weights, scope, loss_collection, reduction=reduction)
+@tf_export("losses.softmax_cross_entropy")
def softmax_cross_entropy(
onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -771,6 +783,7 @@ def _remove_squeezable_dimensions(
return labels, predictions, weights
+@tf_export("losses.sparse_softmax_cross_entropy")
def sparse_softmax_cross_entropy(
labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py
index 3718c481c2..b835d96386 100644
--- a/tensorflow/python/ops/losses/util.py
+++ b/tensorflow/python/ops/losses/util.py
@@ -30,8 +30,10 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("losses.add_loss")
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
"""Adds a externally defined loss to the collection of losses.
@@ -43,6 +45,7 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
ops.add_to_collection(loss_collection, loss)
+@tf_export("losses.get_losses")
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
"""Gets the list of losses from the loss_collection.
@@ -56,6 +59,7 @@ def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
return ops.get_collection(loss_collection, scope)
+@tf_export("losses.get_regularization_losses")
def get_regularization_losses(scope=None):
"""Gets the list of regularization losses.
@@ -68,6 +72,7 @@ def get_regularization_losses(scope=None):
return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
+@tf_export("losses.get_regularization_loss")
def get_regularization_loss(scope=None, name="total_regularization_loss"):
"""Gets the total regularization loss.
@@ -85,6 +90,7 @@ def get_regularization_loss(scope=None, name="total_regularization_loss"):
return constant_op.constant(0.0)
+@tf_export("losses.get_total_loss")
def get_total_loss(add_regularization_losses=True, name="total_loss"):
"""Returns a tensor whose value represents the total loss.
diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py
index 9b92d9a180..cce64c0cca 100644
--- a/tensorflow/python/platform/app.py
+++ b/tensorflow/python/platform/app.py
@@ -23,6 +23,7 @@ import sys as _sys
from tensorflow.python.platform import flags
from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.tf_export import tf_export
def _usage(shorthelp):
@@ -108,6 +109,7 @@ def _define_help_flags():
_define_help_flags_called = True
+@tf_export('app.run')
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py
index 2455acb4c0..8f7b12e2b2 100644
--- a/tensorflow/python/platform/resource_loader.py
+++ b/tensorflow/python/platform/resource_loader.py
@@ -29,8 +29,10 @@ import sys as _sys
from tensorflow.python.util import tf_inspect as _inspect
from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('resource_loader.load_resource')
def load_resource(path):
"""Load the resource at given path, where path is relative to tensorflow/.
@@ -52,6 +54,7 @@ def load_resource(path):
# pylint: disable=protected-access
+@tf_export('resource_loader.get_data_files_path')
def get_data_files_path():
"""Get a direct path to the data files colocated with the script.
@@ -62,6 +65,7 @@ def get_data_files_path():
return _os.path.dirname(_inspect.getfile(_sys._getframe(1)))
+@tf_export('resource_loader.get_root_dir_with_all_resources')
def get_root_dir_with_all_resources():
"""Get a root directory containing all the data attributes in the build rule.
@@ -101,6 +105,7 @@ def get_root_dir_with_all_resources():
return data_files_dir or script_dir
+@tf_export('resource_loader.get_path_to_datafile')
def get_path_to_datafile(path):
"""Get the path to the specified file in the data dependencies.
@@ -120,6 +125,7 @@ def get_path_to_datafile(path):
return _os.path.join(data_files_path, path)
+@tf_export('resource_loader.readahead_file_path')
def readahead_file_path(path, readahead='128M'): # pylint: disable=unused-argument
"""Readahead files not implemented; simply returns given path."""
return path
diff --git a/tensorflow/python/platform/tf_logging.py b/tensorflow/python/platform/tf_logging.py
index 85ed4f071c..22aabfd712 100644
--- a/tensorflow/python/platform/tf_logging.py
+++ b/tensorflow/python/platform/tf_logging.py
@@ -35,6 +35,7 @@ import threading
import six
from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.tf_export import tf_export
# Don't use this directly. Use _get_logger() instead.
@@ -90,30 +91,37 @@ def _get_logger():
_logger_lock.release()
+@tf_export('logging.log')
def log(level, msg, *args, **kwargs):
_get_logger().log(level, msg, *args, **kwargs)
+@tf_export('logging.debug')
def debug(msg, *args, **kwargs):
_get_logger().debug(msg, *args, **kwargs)
+@tf_export('logging.error')
def error(msg, *args, **kwargs):
_get_logger().error(msg, *args, **kwargs)
+@tf_export('logging.fatal')
def fatal(msg, *args, **kwargs):
_get_logger().fatal(msg, *args, **kwargs)
+@tf_export('logging.info')
def info(msg, *args, **kwargs):
_get_logger().info(msg, *args, **kwargs)
+@tf_export('logging.warn')
def warn(msg, *args, **kwargs):
_get_logger().warn(msg, *args, **kwargs)
+@tf_export('logging.warning')
def warning(msg, *args, **kwargs):
_get_logger().warning(msg, *args, **kwargs)
@@ -136,15 +144,18 @@ _log_prefix = None # later set to google2_log_prefix
_log_counter_per_token = {}
+@tf_export('logging.TaskLevelStatusMessage')
def TaskLevelStatusMessage(msg):
error(msg)
+@tf_export('logging.flush')
def flush():
raise NotImplementedError()
# Code below is taken from pyglib/logging
+@tf_export('logging.vlog')
def vlog(level, msg, *args, **kwargs):
_get_logger().log(level, msg, *args, **kwargs)
@@ -164,6 +175,7 @@ def _GetNextLogCountPerToken(token):
return _log_counter_per_token[token]
+@tf_export('logging.log_every_n')
def log_every_n(level, msg, n, *args):
"""Log 'msg % args' at level 'level' once per 'n' times.
@@ -180,6 +192,7 @@ def log_every_n(level, msg, n, *args):
log_if(level, msg, not (count % n), *args)
+@tf_export('logging.log_first_n')
def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name
"""Log 'msg % args' at level 'level' only first 'n' times.
@@ -195,6 +208,7 @@ def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name
log_if(level, msg, count < n, *args)
+@tf_export('logging.log_if')
def log_if(level, msg, condition, *args):
"""Log 'msg % args' at level 'level' only if condition is fulfilled."""
if condition:
@@ -251,11 +265,13 @@ def google2_log_prefix(level, timestamp=None, file_and_line=None):
return s
+@tf_export('logging.get_verbosity')
def get_verbosity():
"""Return how much logging output will be produced."""
return _get_logger().getEffectiveLevel()
+@tf_export('logging.set_verbosity')
def set_verbosity(v):
"""Sets the threshold for what messages will be logged."""
_get_logger().setLevel(v)
@@ -296,4 +312,10 @@ _allowed_symbols = [
'warning',
]
+tf_export('logging.DEBUG').export_constant(__name__, 'DEBUG')
+tf_export('logging.ERROR').export_constant(__name__, 'ERROR')
+tf_export('logging.FATAL').export_constant(__name__, 'FATAL')
+tf_export('logging.INFO').export_constant(__name__, 'INFO')
+tf_export('logging.WARN').export_constant(__name__, 'WARN')
+
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py
index 8f78054560..0e20ca35bb 100644
--- a/tensorflow/python/profiler/model_analyzer.py
+++ b/tensorflow/python/profiler/model_analyzer.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.profiler import option_builder
from tensorflow.python.profiler import tfprof_logger
+from tensorflow.python.util.tf_export import tf_export
_DEFAULT_PROFILE_OPTIONS = 0
_DEFAULT_ADVISE_OPTIONS = 0
@@ -121,6 +122,7 @@ def _build_advisor_options(options):
return opts
+@tf_export('profiler.Profiler')
class Profiler(object):
"""TensorFlow multi-step profiler.
@@ -304,6 +306,7 @@ class Profiler(object):
print_mdl.WriteProfile(filename)
+@tf_export('profiler.profile')
def profile(graph=None,
run_meta=None,
op_log=None,
@@ -378,6 +381,7 @@ def profile(graph=None,
return tfprof_node
+@tf_export('profiler.advise')
def advise(graph=None, run_meta=None, options=_DEFAULT_ADVISE_OPTIONS):
"""Auto profile and advise.
diff --git a/tensorflow/python/profiler/option_builder.py b/tensorflow/python/profiler/option_builder.py
index 13942ad6a2..957ebe6ddd 100644
--- a/tensorflow/python/profiler/option_builder.py
+++ b/tensorflow/python/profiler/option_builder.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import copy
from tensorflow.python.profiler import tfprof_logger
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('profiler.ProfileOptionBuilder')
class ProfileOptionBuilder(object):
# pylint: disable=line-too-long
"""Option Builder for Profiling API.
diff --git a/tensorflow/python/profiler/tfprof_logger.py b/tensorflow/python/profiler/tfprof_logger.py
index ffda7ddad7..8d12106496 100644
--- a/tensorflow/python/profiler/tfprof_logger.py
+++ b/tensorflow/python/profiler/tfprof_logger.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
from tensorflow.python.profiler.internal import flops_registry # pylint: disable=unused-import
+from tensorflow.python.util.tf_export import tf_export
TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'
@@ -187,6 +188,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None,
return tmp_op_log
+@tf_export('profiler.write_op_log')
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
"""Log provided 'op_log', and add additional model information below.
diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py
index 12f120116f..1f3f228704 100644
--- a/tensorflow/python/summary/writer/writer.py
+++ b/tensorflow/python/summary/writer/writer.py
@@ -32,6 +32,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import plugin_asset
from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
+from tensorflow.python.util.tf_export import tf_export
_PLUGINS_DIR = "plugins"
@@ -276,6 +277,7 @@ class SummaryToEventTransformer(object):
self.event_writer.add_event(event)
+@tf_export("summary.FileWriter")
class FileWriter(SummaryToEventTransformer):
"""Writes `Summary` protocol buffers to event files.
diff --git a/tensorflow/python/summary/writer/writer_cache.py b/tensorflow/python/summary/writer/writer_cache.py
index bad289303c..645fa28a37 100644
--- a/tensorflow/python/summary/writer/writer_cache.py
+++ b/tensorflow/python/summary/writer/writer_cache.py
@@ -22,8 +22,10 @@ import threading
from tensorflow.python.framework import ops
from tensorflow.python.summary.writer.writer import FileWriter
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('summary.FileWriterCache')
class FileWriterCache(object):
"""Cache for file writers.
diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py
index 13c07cfd7b..c08e3cca00 100644
--- a/tensorflow/python/training/adadelta.py
+++ b/tensorflow/python/training/adadelta.py
@@ -22,8 +22,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.AdadeltaOptimizer")
class AdadeltaOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adadelta algorithm.
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index afa192f7cc..deb4e6f546 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -25,8 +25,10 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.AdagradOptimizer")
class AdagradOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adagrad algorithm.
diff --git a/tensorflow/python/training/adagrad_da.py b/tensorflow/python/training/adagrad_da.py
index b3f9ea323c..5ba403554f 100644
--- a/tensorflow/python/training/adagrad_da.py
+++ b/tensorflow/python/training/adagrad_da.py
@@ -23,8 +23,10 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.AdagradDAOptimizer")
class AdagradDAOptimizer(optimizer.Optimizer):
"""Adagrad Dual Averaging algorithm for sparse linear models.
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 0c69f8bf39..c92f6fc301 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -26,8 +26,10 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adam algorithm.
diff --git a/tensorflow/python/training/basic_loops.py b/tensorflow/python/training/basic_loops.py
index 52b0f42106..7af821c819 100644
--- a/tensorflow/python/training/basic_loops.py
+++ b/tensorflow/python/training/basic_loops.py
@@ -18,8 +18,10 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import errors
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.basic_train_loop")
def basic_train_loop(supervisor, train_step_fn, args=None,
kwargs=None, master=""):
"""Basic loop to train a model.
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 864c2e4406..aae757b99a 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -47,6 +47,7 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
+from tensorflow.python.util.tf_export import tf_export
class _HookTimer(object):
@@ -85,6 +86,7 @@ class _HookTimer(object):
raise NotImplementedError
+@tf_export("train.SecondOrStepTimer")
class SecondOrStepTimer(_HookTimer):
"""Timer that triggers at most once every N seconds or once every N steps.
"""
@@ -164,6 +166,7 @@ class NeverTriggerTimer(_HookTimer):
return None
+@tf_export("train.LoggingTensorHook")
class LoggingTensorHook(session_run_hook.SessionRunHook):
"""Prints the given tensors every N local steps, every N seconds, or at end.
@@ -262,6 +265,7 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
self._log_tensors(values)
+@tf_export("train.StopAtStepHook")
class StopAtStepHook(session_run_hook.SessionRunHook):
"""Hook that requests stop at a specified step."""
@@ -317,6 +321,7 @@ class StopAtStepHook(session_run_hook.SessionRunHook):
run_context.request_stop()
+@tf_export("train.CheckpointSaverListener")
class CheckpointSaverListener(object):
"""Interface for listeners that take action before or after checkpoint save.
@@ -375,6 +380,7 @@ class CheckpointSaverListener(object):
pass
+@tf_export("train.CheckpointSaverHook")
class CheckpointSaverHook(session_run_hook.SessionRunHook):
"""Saves checkpoints every N steps or seconds."""
@@ -497,6 +503,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
return savers[0]
+@tf_export("train.StepCounterHook")
class StepCounterHook(session_run_hook.SessionRunHook):
"""Hook that counts steps per second."""
@@ -575,12 +582,14 @@ class StepCounterHook(session_run_hook.SessionRunHook):
self._last_global_step = stale_global_step
+@tf_export("train.NanLossDuringTrainingError")
class NanLossDuringTrainingError(RuntimeError):
def __str__(self):
return "NaN loss during training."
+@tf_export("train.NanTensorHook")
class NanTensorHook(session_run_hook.SessionRunHook):
"""Monitors the loss tensor and stops training if loss is NaN.
@@ -612,6 +621,7 @@ class NanTensorHook(session_run_hook.SessionRunHook):
run_context.request_stop()
+@tf_export("train.SummarySaverHook")
class SummarySaverHook(session_run_hook.SessionRunHook):
"""Saves summaries every N steps."""
@@ -720,6 +730,7 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
return summary_op
+@tf_export("train.GlobalStepWaiterHook")
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
"""Delays execution until global step reaches `wait_until_step`.
@@ -767,6 +778,7 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
time.sleep(0.5)
+@tf_export("train.FinalOpsHook")
class FinalOpsHook(session_run_hook.SessionRunHook):
"""A hook which evaluates `Tensors` at the end of a session."""
@@ -793,6 +805,7 @@ class FinalOpsHook(session_run_hook.SessionRunHook):
feed_dict=self._final_ops_feed_dict)
+@tf_export("train.FeedFnHook")
class FeedFnHook(session_run_hook.SessionRunHook):
"""Runs `feed_fn` and sets the `feed_dict` accordingly."""
@@ -810,6 +823,7 @@ class FeedFnHook(session_run_hook.SessionRunHook):
fetches=None, feed_dict=self.feed_fn())
+@tf_export("train.ProfilerHook")
class ProfilerHook(session_run_hook.SessionRunHook):
"""Captures CPU/GPU profiling information every N steps or seconds.
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 63235a1454..fa3de6fad2 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -36,6 +37,7 @@ __all__ = [
]
+@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
"""Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.
@@ -60,6 +62,7 @@ def load_checkpoint(ckpt_dir_or_file):
return pywrap_tensorflow.NewCheckpointReader(filename)
+@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
"""Returns the tensor value of the given variable in the checkpoint.
@@ -77,6 +80,7 @@ def load_variable(ckpt_dir_or_file, name):
return reader.get_tensor(name)
+@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
"""Returns list of all variables in the checkpoint.
@@ -95,6 +99,7 @@ def list_variables(ckpt_dir_or_file):
return result
+@tf_export("train.init_from_checkpoint")
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
"""Initializes current variables with tensors loaded from given checkpoint.
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index 0e31255b74..0ff97d85e3 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -27,8 +27,10 @@ import six
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.Coordinator")
class Coordinator(object):
"""A coordinator for threads.
@@ -406,6 +408,7 @@ class Coordinator(object):
# Threads for the standard services.
+@tf_export("train.LooperThread")
class LooperThread(threading.Thread):
"""A thread that runs code repeatedly, optionally on a timer.
diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py
index 37ab625779..689088bb41 100644
--- a/tensorflow/python/training/device_setter.py
+++ b/tensorflow/python/training/device_setter.py
@@ -23,6 +23,7 @@ from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
+from tensorflow.python.util.tf_export import tf_export
class _RoundRobinStrategy(object):
@@ -121,6 +122,7 @@ class _ReplicaDeviceChooser(object):
return worker_device.to_string()
+@tf_export("train.replica_device_setter")
def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
worker_device="/job:worker", merge_devices=True,
cluster=None, ps_ops=None, ps_strategy=None):
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
index c64a1b3f79..9d02e694db 100644
--- a/tensorflow/python/training/ftrl.py
+++ b/tensorflow/python/training/ftrl.py
@@ -22,8 +22,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.FtrlOptimizer")
class FtrlOptimizer(optimizer.Optimizer):
"""Optimizer that implements the FTRL algorithm.
@@ -265,4 +267,3 @@ class FtrlOptimizer(optimizer.Optimizer):
grad.dtype),
math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
use_locking=self._use_locking)
-
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index 5a536e2729..380e14e024 100644
--- a/tensorflow/python/training/gradient_descent.py
+++ b/tensorflow/python/training/gradient_descent.py
@@ -23,8 +23,10 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.GradientDescentOptimizer")
class GradientDescentOptimizer(optimizer.Optimizer):
"""Optimizer that implements the gradient descent algorithm.
"""
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 7160420a33..bd9985a7c5 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.summary import summary
from tensorflow.python.training import queue_runner
+from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@@ -53,6 +54,7 @@ _restore_sparse = sparse_ops._take_many_sparse_from_tensors_map
# pylint: enable=protected-access
+@tf_export("train.match_filenames_once")
def match_filenames_once(pattern, name=None):
"""Save the list of files matching pattern, so it is only computed once.
@@ -72,6 +74,7 @@ def match_filenames_once(pattern, name=None):
collections=[ops.GraphKeys.LOCAL_VARIABLES])
+@tf_export("train.limit_epochs")
def limit_epochs(tensor, num_epochs=None, name=None):
"""Returns tensor `num_epochs` times and then raises an `OutOfRange` error.
@@ -104,6 +107,7 @@ def limit_epochs(tensor, num_epochs=None, name=None):
return array_ops.identity(tensor, name=name)
+@tf_export("train.input_producer")
def input_producer(input_tensor,
element_shape=None,
num_epochs=None,
@@ -186,6 +190,7 @@ def input_producer(input_tensor,
return q
+@tf_export("train.string_input_producer")
def string_input_producer(string_tensor,
num_epochs=None,
shuffle=True,
@@ -255,6 +260,7 @@ def string_input_producer(string_tensor,
cancel_op=cancel_op)
+@tf_export("train.range_input_producer")
def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None):
"""Produces the integers from 0 to limit-1 in a queue.
@@ -292,6 +298,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
shared_name, "fraction_of_%d_full" % capacity, name)
+@tf_export("train.slice_input_producer")
def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None):
"""Produces a slice of each `Tensor` in `tensor_list`.
@@ -887,6 +894,7 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity,
# Batching functions ----------------------------------------------------------
+@tf_export("train.batch")
def batch(tensors, batch_size, num_threads=1, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -981,6 +989,7 @@ def batch(tensors, batch_size, num_threads=1, capacity=32,
name=name)
+@tf_export("train.maybe_batch")
def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -1033,6 +1042,7 @@ def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32,
name=name)
+@tf_export("train.batch_join")
def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
shapes=None, dynamic_pad=False, allow_smaller_final_batch=False,
shared_name=None, name=None):
@@ -1138,6 +1148,7 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
name=name)
+@tf_export("train.maybe_batch_join")
def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None,
@@ -1190,6 +1201,7 @@ def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32,
name=name)
+@tf_export("train.shuffle_batch")
def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
num_threads=1, seed=None, enqueue_many=False, shapes=None,
allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -1289,6 +1301,7 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
name=name)
+@tf_export("train.maybe_shuffle_batch")
def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
keep_input, num_threads=1, seed=None,
enqueue_many=False, shapes=None,
@@ -1348,6 +1361,7 @@ def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
name=name)
+@tf_export("train.shuffle_batch_join")
def shuffle_batch_join(tensors_list, batch_size, capacity,
min_after_dequeue, seed=None, enqueue_many=False,
shapes=None, allow_smaller_final_batch=False,
@@ -1441,6 +1455,7 @@ def shuffle_batch_join(tensors_list, batch_size, capacity,
name=name)
+@tf_export("train.maybe_shuffle_batch_join")
def maybe_shuffle_batch_join(tensors_list, batch_size, capacity,
min_after_dequeue, keep_input, seed=None,
enqueue_many=False, shapes=None,
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index 343a49cded..10ab4c1137 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -25,8 +25,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.exponential_decay")
def exponential_decay(learning_rate,
global_step,
decay_steps,
@@ -103,6 +105,7 @@ def exponential_decay(learning_rate,
learning_rate, math_ops.pow(decay_rate, p), name=name)
+@tf_export("train.piecewise_constant")
def piecewise_constant(x, boundaries, values, name=None):
"""Piecewise constant from boundaries and interval values.
@@ -182,6 +185,7 @@ def piecewise_constant(x, boundaries, values, name=None):
return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
+@tf_export("train.polynomial_decay")
def polynomial_decay(learning_rate,
global_step,
decay_steps,
@@ -291,6 +295,7 @@ def polynomial_decay(learning_rate,
name=name)
+@tf_export("train.natural_exp_decay")
def natural_exp_decay(learning_rate,
global_step,
decay_steps,
@@ -362,6 +367,7 @@ def natural_exp_decay(learning_rate,
return math_ops.multiply(learning_rate, exponent, name=name)
+@tf_export("train.inverse_time_decay")
def inverse_time_decay(learning_rate,
global_step,
decay_steps,
@@ -444,6 +450,7 @@ def inverse_time_decay(learning_rate,
return math_ops.div(learning_rate, denom, name=name)
+@tf_export("train.cosine_decay")
def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
"""Applies cosine decay to the learning rate.
@@ -503,6 +510,7 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
return math_ops.multiply(learning_rate, decayed)
+@tf_export("train.cosine_decay_restarts")
def cosine_decay_restarts(learning_rate,
global_step,
first_decay_steps,
@@ -596,6 +604,7 @@ def cosine_decay_restarts(learning_rate,
return math_ops.multiply(learning_rate, decayed, name=name)
+@tf_export("train.linear_cosine_decay")
def linear_cosine_decay(learning_rate,
global_step,
decay_steps,
@@ -679,6 +688,7 @@ def linear_cosine_decay(learning_rate,
return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
+@tf_export("train.noisy_linear_cosine_decay")
def noisy_linear_cosine_decay(learning_rate,
global_step,
decay_steps,
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
index cf9530d87c..bd9fa79d8f 100644
--- a/tensorflow/python/training/momentum.py
+++ b/tensorflow/python/training/momentum.py
@@ -22,8 +22,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.MomentumOptimizer")
class MomentumOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Momentum algorithm.
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index fa3517db27..6c5c9e01a7 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -41,6 +41,7 @@ from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
+from tensorflow.python.util.tf_export import tf_export
# The list of exceptions that we should recover from. Exceptions not in this
@@ -52,6 +53,7 @@ _PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)
USE_DEFAULT = object()
+@tf_export('train.Scaffold')
class Scaffold(object):
"""Structure to create or gather pieces commonly needed to train a model.
@@ -272,6 +274,7 @@ class Scaffold(object):
resources.initialize_resources(resources.local_resources()))
+@tf_export('train.MonitoredTrainingSession')
def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
is_chief=True,
checkpoint_dir=None,
@@ -381,6 +384,7 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
stop_grace_period_secs=stop_grace_period_secs)
+@tf_export('train.SessionCreator')
class SessionCreator(object):
"""A factory for tf.Session."""
@@ -390,6 +394,7 @@ class SessionCreator(object):
'create_session is not implemented for {}.'.format(self))
+@tf_export('train.ChiefSessionCreator')
class ChiefSessionCreator(SessionCreator):
"""Creates a tf.Session for a chief."""
@@ -441,6 +446,7 @@ class ChiefSessionCreator(SessionCreator):
init_fn=self._scaffold.init_fn)
+@tf_export('train.WorkerSessionCreator')
class WorkerSessionCreator(SessionCreator):
"""Creates a tf.Session for a worker."""
@@ -706,6 +712,7 @@ class _MonitoredSession(object):
return self._coordinated_creator.tf_sess
+@tf_export('train.MonitoredSession')
class MonitoredSession(_MonitoredSession):
"""Session-like object that handles initialization, recovery and hooks.
@@ -788,6 +795,7 @@ class MonitoredSession(_MonitoredSession):
stop_grace_period_secs=stop_grace_period_secs)
+@tf_export('train.SingularMonitoredSession')
class SingularMonitoredSession(_MonitoredSession):
"""Session-like object that handles initialization, restoring, and hooks.
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 43ed1ac170..2d89082ad7 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
+from tensorflow.python.util.tf_export import tf_export
# TODO(touts): switch to variables.Variable.
@@ -230,6 +231,7 @@ def _zero_debias(unbiased_var, value, decay):
return unbiased_ema_delta
+@tf_export("train.ExponentialMovingAverage")
class ExponentialMovingAverage(object):
"""Maintains moving averages of variables by employing an exponential decay.
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index a06b3eada6..425dbd8313 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
def _get_variable_for(v):
@@ -187,6 +188,7 @@ def _get_processor(v):
raise NotImplementedError("Trying to optimize unsupported type ", v)
+@tf_export("train.Optimizer")
class Optimizer(object):
"""Base class for optimizers.
diff --git a/tensorflow/python/training/proximal_adagrad.py b/tensorflow/python/training/proximal_adagrad.py
index da31ab325d..9bd677b8ef 100644
--- a/tensorflow/python/training/proximal_adagrad.py
+++ b/tensorflow/python/training/proximal_adagrad.py
@@ -23,8 +23,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.ProximalAdagradOptimizer")
class ProximalAdagradOptimizer(optimizer.Optimizer):
# pylint: disable=line-too-long
"""Optimizer that implements the Proximal Adagrad algorithm.
diff --git a/tensorflow/python/training/proximal_gradient_descent.py b/tensorflow/python/training/proximal_gradient_descent.py
index 53e9dc2ef2..369b6cbb50 100644
--- a/tensorflow/python/training/proximal_gradient_descent.py
+++ b/tensorflow/python/training/proximal_gradient_descent.py
@@ -24,8 +24,10 @@ from tensorflow.python.ops import math_ops
# pylint: enable=unused-import
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.ProximalGradientDescentOptimizer")
class ProximalGradientDescentOptimizer(optimizer.Optimizer):
# pylint: disable=line-too-long
"""Optimizer that implements the proximal gradient descent algorithm.
diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py
index 4e7c81d7b2..07afba79ab 100644
--- a/tensorflow/python/training/queue_runner_impl.py
+++ b/tensorflow/python/training/queue_runner_impl.py
@@ -27,8 +27,10 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.queue_runner.QueueRunner", "train.QueueRunner")
class QueueRunner(object):
"""Holds a list of enqueue operations for a queue, each to be run in a thread.
@@ -384,6 +386,7 @@ class QueueRunner(object):
import_scope=import_scope)
+@tf_export("train.queue_runner.add_queue_runner", "train.add_queue_runner")
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
"""Adds a `QueueRunner` to a collection in the graph.
@@ -402,6 +405,8 @@ def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
ops.add_to_collection(collection, qr)
+@tf_export("train.queue_runner.start_queue_runners",
+ "train.start_queue_runners")
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
collection=ops.GraphKeys.QUEUE_RUNNERS):
"""Starts all queue runners collected in the graph.
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
index 745e612018..89d1099a49 100644
--- a/tensorflow/python/training/rmsprop.py
+++ b/tensorflow/python/training/rmsprop.py
@@ -46,8 +46,10 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.RMSPropOptimizer")
class RMSPropOptimizer(optimizer.Optimizer):
"""Optimizer that implements the RMSProp algorithm.
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 554472e043..764f840012 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -53,6 +53,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
# Op names which identify variable reads which should be saved.
@@ -889,6 +890,7 @@ def _GetCheckpointFilename(save_dir, latest_filename):
return os.path.join(save_dir, latest_filename)
+@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None):
@@ -933,6 +935,7 @@ def generate_checkpoint_state_proto(save_dir,
return coord_checkpoint_proto
+@tf_export("train.update_checkpoint_state")
def update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
@@ -1025,6 +1028,7 @@ def _update_checkpoint_state(save_dir,
text_format.MessageToString(ckpt))
+@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
"""Returns CheckpointState proto from the "checkpoint" file.
@@ -1082,6 +1086,7 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
return ckpt
+@tf_export("train.Saver")
class Saver(object):
"""Saves and restores variables.
@@ -1788,6 +1793,7 @@ def _prefix_to_checkpoint_path(prefix, format_version):
return prefix # Just the data file.
+@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
"""Finds the filename of latest saved checkpoint file.
@@ -1817,6 +1823,7 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None):
return None
+@tf_export("train.import_meta_graph")
def import_meta_graph(meta_graph_or_file, clear_devices=False,
import_scope=None, **kwargs):
"""Recreates a Graph saved in a `MetaGraphDef` proto.
@@ -1918,6 +1925,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
return None
+@tf_export("train.export_meta_graph")
def export_meta_graph(filename=None,
meta_info_def=None,
graph_def=None,
@@ -1994,6 +2002,7 @@ def export_meta_graph(filename=None,
return meta_graph_def
+@tf_export("train.checkpoint_exists")
def checkpoint_exists(checkpoint_prefix):
"""Checks whether a V1 or V2 checkpoint exists with the specified prefix.
@@ -2018,6 +2027,7 @@ def checkpoint_exists(checkpoint_prefix):
return False
+@tf_export("train.get_checkpoint_mtimes")
def get_checkpoint_mtimes(checkpoint_prefixes):
"""Returns the mtimes (modification timestamps) of the checkpoints.
diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py
index 29da67a30a..2f421d1cc0 100644
--- a/tensorflow/python/training/server_lib.py
+++ b/tensorflow/python/training/server_lib.py
@@ -23,6 +23,7 @@ from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
@@ -92,6 +93,7 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
return server_def
+@tf_export("train.Server")
class Server(object):
"""An in-process TensorFlow server, for use in distributed training.
@@ -221,6 +223,7 @@ class Server(object):
start=start)
+@tf_export("train.ClusterSpec")
class ClusterSpec(object):
"""Represents a cluster as a set of "tasks", organized into "jobs".
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index b396a1e7d0..360e02fb44 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_mod
+from tensorflow.python.util.tf_export import tf_export
def _maybe_name(obj):
@@ -44,6 +45,7 @@ def _maybe_name(obj):
return "<no name for %s>" % type(obj)
+@tf_export("train.SessionManager")
class SessionManager(object):
"""Training helper that restores from checkpoint and creates session.
diff --git a/tensorflow/python/training/session_run_hook.py b/tensorflow/python/training/session_run_hook.py
index 5b023d8a26..89f4030065 100644
--- a/tensorflow/python/training/session_run_hook.py
+++ b/tensorflow/python/training/session_run_hook.py
@@ -96,8 +96,10 @@ from __future__ import division
from __future__ import print_function
import collections
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.SessionRunHook")
class SessionRunHook(object):
"""Hook to extend calls to MonitoredSession.run()."""
@@ -189,6 +191,7 @@ class SessionRunHook(object):
pass
+@tf_export("train.SessionRunArgs")
class SessionRunArgs(
collections.namedtuple("SessionRunArgs",
["fetches", "feed_dict", "options"])):
@@ -213,6 +216,7 @@ class SessionRunArgs(
return super(SessionRunArgs, cls).__new__(cls, fetches, feed_dict, options)
+@tf_export("train.SessionRunContext")
class SessionRunContext(object):
"""Provides information about the `session.run()` call being made.
@@ -264,6 +268,7 @@ class SessionRunContext(object):
self._stop_requested = True
+@tf_export("train.SessionRunValues")
class SessionRunValues(
collections.namedtuple("SessionRunValues",
["results", "options", "run_metadata"])):
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index e4514aaea2..d2ad34773e 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -37,8 +37,10 @@ from tensorflow.python.training import saver as saver_mod
from tensorflow.python.training import session_manager as session_manager_mod
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("train.Supervisor")
class Supervisor(object):
"""A training helper that checkpoints models and computes summaries.
diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py
index 47702fdad0..0c6cf910d1 100644
--- a/tensorflow/python/training/sync_replicas_optimizer.py
+++ b/tensorflow/python/training/sync_replicas_optimizer.py
@@ -31,6 +31,7 @@ from tensorflow.python.training import optimizer
from tensorflow.python.training import queue_runner
from tensorflow.python.training import session_manager
from tensorflow.python.training import session_run_hook
+from tensorflow.python.util.tf_export import tf_export
# Please note that the gradients from replicas are averaged instead of summed
@@ -38,6 +39,7 @@ from tensorflow.python.training import session_run_hook
# rate according to the number of replicas. This change is introduced to be
# consistent with how gradients are aggregated (averaged) within a batch in a
# replica.
+@tf_export("train.SyncReplicasOptimizer")
class SyncReplicasOptimizer(optimizer.Optimizer):
"""Class to synchronize, aggregate gradients and pass them to the optimizer.
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 89a9e12932..499f1feb2d 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
# Picked a long key value to minimize the chance of collision with user defined
@@ -40,6 +41,7 @@ GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'
write_graph = graph_io.write_graph
+@tf_export('train.global_step')
def global_step(sess, global_step_tensor):
"""Small helper to get the global step.
@@ -67,6 +69,7 @@ def global_step(sess, global_step_tensor):
return int(sess.run(global_step_tensor))
+@tf_export('train.get_global_step')
def get_global_step(graph=None):
"""Get the global step tensor.
@@ -101,6 +104,7 @@ def get_global_step(graph=None):
return global_step_tensor
+@tf_export('train.create_global_step')
def create_global_step(graph=None):
"""Create global step tensor in graph.
@@ -139,6 +143,7 @@ def create_global_step(graph=None):
ops.GraphKeys.GLOBAL_STEP])
+@tf_export('train.get_or_create_global_step')
def get_or_create_global_step(graph=None):
"""Returns and create (if necessary) the global step tensor.
@@ -156,6 +161,7 @@ def get_or_create_global_step(graph=None):
return global_step_tensor
+@tf_export('train.assert_global_step')
def assert_global_step(global_step_tensor):
"""Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.
diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py
index 270d96a3c7..7e5f192b8f 100644
--- a/tensorflow/python/util/compat.py
+++ b/tensorflow/python/util/compat.py
@@ -41,8 +41,10 @@ import numpy as _np
import six as _six
from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('compat.as_bytes', 'compat.as_str')
def as_bytes(bytes_or_text, encoding='utf-8'):
"""Converts either bytes or unicode to `bytes`, using utf-8 encoding for text.
@@ -65,6 +67,7 @@ def as_bytes(bytes_or_text, encoding='utf-8'):
(bytes_or_text,))
+@tf_export('compat.as_text')
def as_text(bytes_or_text, encoding='utf-8'):
"""Returns the given argument as a unicode string.
@@ -93,6 +96,7 @@ else:
as_str = as_text
+@tf_export('compat.as_str_any')
def as_str_any(value):
"""Converts to `str` as `str(value)`, but use `as_str` for `bytes`.
@@ -125,11 +129,16 @@ def path_to_str(path):
# Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we
# need to check them specifically. The same goes from Real and Complex.
integral_types = (_numbers.Integral, _np.integer)
+tf_export('compat.integral_types').export_constant(__name__, 'integral_types')
real_types = (_numbers.Real, _np.integer, _np.floating)
+tf_export('compat.real_types').export_constant(__name__, 'real_types')
complex_types = (_numbers.Complex, _np.number)
+tf_export('compat.complex_types').export_constant(__name__, 'complex_types')
# Either bytes or text.
bytes_or_text_types = (bytes, _six.text_type)
+tf_export('compat.bytes_or_text_types').export_constant(__name__,
+ 'bytes_or_text_types')
_allowed_symbols = [
'as_str',
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index d110316395..66bbd572a6 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -77,6 +77,16 @@ genrule(
"api/nn/rnn_cell/__init__.py",
"api/sets/__init__.py",
"api/summary/__init__.py",
+ "api/train/queue_runner/__init__.py",
+ "api/compat/__init__.py",
+ "api/data/__init__.py",
+ "api/estimator/__init__.py",
+ "api/estimator/export/__init__.py",
+ "api/estimator/inputs/__init__.py",
+ "api/feature_column/__init__.py",
+ "api/losses/__init__.py",
+ "api/profiler/__init__.py",
+ "api/python_io/__init__.py",
],
cmd = "$(location create_python_api) $(OUTS)",
tools = ["create_python_api"],
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 106ea19d46..a58db51cb8 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -517,9 +517,14 @@ do_check_futures_test() {
python check_futures_test.py
}
+do_check_file_name_test() {
+ cd "$ROOT_DIR/tensorflow/tools/test"
+ python file_name_test.py
+}
+
# Supply all sanity step commands and descriptions
-SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_cmake_python_sanity")
-SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency")
+SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_cmake_python_sanity" "do_check_file_name_test")
+SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency" "Check file names for cases")
INCREMENTAL_FLAG=""
DEFAULT_BAZEL_CONFIGS="--config=hdfs --config=gcp"
diff --git a/tensorflow/tools/dist_test/build_server.sh b/tensorflow/tools/dist_test/build_server.sh
index 878fabd248..225c034741 100755
--- a/tensorflow/tools/dist_test/build_server.sh
+++ b/tensorflow/tools/dist_test/build_server.sh
@@ -16,14 +16,15 @@
#
# Builds the test server for distributed (GRPC) TensorFlow
#
-# Usage: build_server.sh <docker_image_name> <whl_url> [--test]
+# Usage: build_server.sh <docker_image_name> <whl_file_location> [--test]
#
# Arguments:
# docker_image_name: Name of the docker image to build.
# E.g.: tensorflow/tf_grpc_test_server:0.11.0rc1
#
-# whl_url: URL from which the TensorFlow whl file will be downloaded.
+# whl_file_location: URL from which the TensorFlow whl file will be downloaded.
# E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl
+# E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl
#
# The optional flag --test lets the script to use the Dockerfile for the
# testing GRPC server. Without the flag, the script will build the non-test
@@ -41,11 +42,11 @@ die() {
# Check arguments
if [[ $# -lt 2 ]]; then
- die "Usage: $0 <docker_image_name> <whl_url> [--test]"
+ die "Usage: $0 <docker_image_name> <whl_location> [--test]"
fi
DOCKER_IMG_NAME=$1
-WHL_URL=$2
+WHL_FILE_LOCATION=$2
shift 2
# Current script directory
@@ -53,7 +54,7 @@ DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BUILD_DIR=$(mktemp -d)
echo ""
-echo "Using whl file URL: ${WHL_URL}"
+echo "Using whl file URL: ${WHL_FILE_LOCATION}"
echo "Building in temporary directory: ${BUILD_DIR}"
cp -r ${DIR}/* "${BUILD_DIR}"/ || \
@@ -65,9 +66,15 @@ if [[ $1 == "--test" ]]; then
fi
echo "Using Docker file: ${DOCKER_FILE}"
+if [[ $WHL_FILE_LOCATION =~ 'http://' || $WHL_FILE_LOCATION =~ 'https://' ]]; then
+ # Download whl file into the build context directory.
+ wget -P "${BUILD_DIR}" "${WHL_FILE_LOCATION}" || \
+ die "Failed to download tensorflow whl file from URL: ${WHL_FILE_LOCATION}"
+else
+ cp "${WHL_FILE_LOCATION}" "${BUILD_DIR}"
+fi
+
# Download whl file into the build context directory.
-wget -P "${BUILD_DIR}" ${WHL_URL} || \
- die "Failed to download tensorflow whl file from URL: ${WHL_URL}"
if [[ ! -f "${DOCKER_FILE}" ]]; then
die "ERROR: Unable to find dockerfile: ${DOCKER_FILE}"
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 593c654f9f..9c583d83ca 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -181,6 +181,14 @@ Status ObtainVariableInfo(
return Status::OK();
}
+Status RemoveInputAtIndex(NodeDef* n, int index) {
+ for (int i = index; i < n->input_size() - 1; i++) {
+ n->mutable_input()->SwapElements(i, i + 1);
+ }
+ n->mutable_input()->RemoveLast();
+ return Status::OK();
+}
+
Status SparsifyGatherInternal(
const GraphDef& input_graph_def,
const std::unique_ptr<std::unordered_map<string, string> >&
@@ -301,13 +309,13 @@ Status SparsifyGatherInternal(
TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint(
weights_node.name(), ckpt_reader,
(*shapes_and_slices)[weights_node.name()], &weight));
- // Add both both weight and identity node names.
- removed_node_names.push_back(weights_node.name());
- removed_node_names.push_back(match.inputs[0].node.name());
- for (auto input_node : match.inputs[0].node.input()) {
- auto parsed_input = StringReplace(input_node, "^", "", true);
- refs[parsed_input]--;
- }
+ }
+ // Add both both weight and identity node names.
+ removed_node_names.push_back(weights_node.name());
+ removed_node_names.push_back(match.inputs[0].node.name());
+ for (auto input_node : match.inputs[0].node.input()) {
+ auto parsed_input = StringReplace(input_node, "^", "", true);
+ refs[parsed_input]--;
}
Tensor indices_tensor;
Tensor values_tensor;
@@ -468,26 +476,49 @@ Status SparsifyGatherInternal(
continue;
}
int j = 0;
+ bool deleted_inputs = false;
while (j < replaced_graph_def.node(i).input_size()) {
if (replaced_graph_def.node(i).input(j) == name ||
replaced_graph_def.node(i).input(j) == ("^" + name)) {
- replaced_graph_def.mutable_node(i)->mutable_input()->SwapElements(
- j, replaced_graph_def.node(i).input_size() - 1);
- replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
+ TF_RETURN_IF_ERROR(
+ RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j));
+ deleted_inputs = true;
continue;
}
j++;
}
- if (!replaced_graph_def.node(i).input_size()) {
- if ((refs.find(replaced_graph_def.node(i).name()) != refs.end()) &&
- (refs[replaced_graph_def.node(i).name()] == 0)) {
+ if (deleted_inputs) {
+ if (replaced_graph_def.node(i).op() == "ConcatV2") {
+ if (replaced_graph_def.node(i).input_size() > 2) {
+ SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1,
+ replaced_graph_def.mutable_node(i));
+ } else if (replaced_graph_def.node(i).input_size() == 2) {
+ if (refs[replaced_graph_def.node(i).input(1)] != 1) {
+ return errors::Internal(
+ "Expect axis tensor of ConcatV2 node to only be referenced "
+ "once.");
+ }
+ refs[replaced_graph_def.node(i).input(1)] -= 1;
+ removed_node_names.push_back(replaced_graph_def.node(i).input(1));
+ replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
+ replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N");
+ replaced_graph_def.mutable_node(i)->set_op("Identity");
+ } else {
+ return errors::Internal(
+ "ConcatV2 should have at least two elements");
+ }
+ }
+ if ((replaced_graph_def.node(i).op() == "Assign" ||
+ replaced_graph_def.node(i).op() == "Reshape" ||
+ replaced_graph_def.node(i).op() == "Equal" ||
+ replaced_graph_def.node(i).op() == "Mean" ||
+ replaced_graph_def.node(i).op() == "ScalarSummary") &&
+ replaced_graph_def.node(i).input_size() == 1) {
+ removed_node_names.push_back(replaced_graph_def.node(i).name());
+ }
+ if (!replaced_graph_def.node(i).input_size()) {
removed_node_names.push_back(replaced_graph_def.node(i).name());
}
- }
-
- if (replaced_graph_def.node(i).op() == "Assign" &&
- replaced_graph_def.node(i).input_size() == 1) {
- removed_node_names.push_back(replaced_graph_def.node(i).name());
}
i++;
}
@@ -528,17 +559,22 @@ Status SparsifyGather(const GraphDef& input_graph_def,
};
// clang-format on
+ GraphDef cleaned_input_graph_def;
+ RemoveAttributes(input_graph_def, {"_output_shapes"},
+ &cleaned_input_graph_def);
+
GraphDef temp_output;
std::unique_ptr<BundleReader> ckpt_reader;
TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader));
std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices;
- TF_RETURN_IF_ERROR(ObtainVariableInfo(input_graph_def, &shapes_and_slices));
+ TF_RETURN_IF_ERROR(
+ ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices));
- TF_RETURN_IF_ERROR(SparsifyGatherInternal(input_graph_def, shapes_and_slices,
- context, gather_pattern,
- ckpt_reader, &temp_output));
+ TF_RETURN_IF_ERROR(SparsifyGatherInternal(
+ cleaned_input_graph_def, shapes_and_slices, context, gather_pattern,
+ ckpt_reader, &temp_output));
TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices,
context, gather_v2_pattern,
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
index 6627df1331..203ed3e0f9 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
@@ -71,7 +71,7 @@ class SparsifyGatherTest : public ::testing::Test {
}
void TestSinglePartition(bool gather_v2, bool include_shared_init,
- bool test_variable,
+ bool test_variable, bool test_kept_concat,
const string& shared_init_name = "group_deps") {
GraphDef graph_def;
@@ -139,6 +139,26 @@ class SparsifyGatherTest : public ::testing::Test {
}
}
+ NodeDef* concat_axis_node =
+ CreateNode("linear/concat/axis", "Const", {}, &graph_def);
+ NodeDef* concat_input_node =
+ CreateNode("concat/input/node", "Const", {}, &graph_def);
+ NodeDef* concat_node = nullptr;
+ if (!test_kept_concat) {
+ concat_node = CreateNode(
+ "concat/node", "ConcatV2",
+ {identity_node, concat_input_node, concat_axis_node}, &graph_def);
+ SetNodeAttr("N", 2, concat_node);
+ } else {
+ NodeDef* concat_input_node_2 =
+ CreateNode("concat/input/node_2", "Const", {}, &graph_def);
+ concat_node = CreateNode("concat/node", "ConcatV2",
+ {identity_node, concat_input_node,
+ concat_input_node_2, concat_axis_node},
+ &graph_def);
+ SetNodeAttr("N", 3, concat_node);
+ }
+
// Run the op.
GraphDef result;
TransformFuncContext context;
@@ -166,6 +186,23 @@ class SparsifyGatherTest : public ::testing::Test {
EXPECT_EQ(1, node_lookup.count("ids"));
EXPECT_EQ("Const", node_lookup.at("ids")->op());
+ EXPECT_EQ(1, node_lookup.count("concat/node"));
+
+ if (!test_kept_concat) {
+ EXPECT_EQ(0, node_lookup.count("linear/concat/axis"));
+ EXPECT_EQ("Identity", node_lookup.at("concat/node")->op());
+ EXPECT_EQ(1, node_lookup.at("concat/node")->input_size());
+ EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0));
+ } else {
+ EXPECT_EQ(1, node_lookup.count("linear/concat/axis"));
+ EXPECT_EQ("ConcatV2", node_lookup.at("concat/node")->op());
+ EXPECT_EQ(3, node_lookup.at("concat/node")->input_size());
+ EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0));
+ EXPECT_EQ("concat/input/node_2", node_lookup.at("concat/node")->input(1));
+ EXPECT_EQ("linear/concat/axis", node_lookup.at("concat/node")->input(2));
+ EXPECT_EQ(2, node_lookup.at("concat/node")->attr().at("N").i());
+ }
+
EXPECT_EQ(1, node_lookup.count("w/part_1/indices"));
EXPECT_EQ("Const", node_lookup.at("w/part_1/indices")->op());
Tensor expected_indices_tensor(DT_INT64, TensorShape({3}));
@@ -344,6 +381,13 @@ class SparsifyGatherTest : public ::testing::Test {
MakeGather("gather1", gather_v2, identity_node1, input_node, &graph_def);
MakeGather("gather2", gather_v2, identity_node2, input_node, &graph_def);
+ NodeDef* concat_axis_node =
+ CreateNode("linear/concat/axis", "Const", {}, &graph_def);
+ NodeDef* concat_node = CreateNode(
+ "concat/node", "ConcatV2",
+ {identity_node1, identity_node2, concat_axis_node}, &graph_def);
+ SetNodeAttr("N", 2, concat_node);
+
// Shared init node
if (include_shared_init) {
if (!test_variable) {
@@ -515,6 +559,9 @@ class SparsifyGatherTest : public ::testing::Test {
node_lookup.at("gather2/LookupTableFind")->input(2));
EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0));
+ EXPECT_EQ(0, node_lookup.count("linear/concat/axis"));
+ EXPECT_EQ(0, node_lookup.count("concat/node"));
+
// Check control deps.
EXPECT_EQ(2, node_lookup.at(shared_init_name)->input_size());
EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(),
@@ -550,18 +597,31 @@ class SparsifyGatherTest : public ::testing::Test {
};
TEST_F(SparsifyGatherTest, TestSinglePartition) {
- TestSinglePartition(false, false, false);
- TestSinglePartition(false, true, false);
- TestSinglePartition(true, false, false);
- TestSinglePartition(true, true, false);
- TestSinglePartition(false, false, true);
- TestSinglePartition(false, true, true);
- TestSinglePartition(true, false, true);
- TestSinglePartition(true, true, true);
- TestSinglePartition(false, true, false, "shared_inits");
- TestSinglePartition(true, true, false, "shared_inits");
- TestSinglePartition(false, true, true, "shared_inits");
- TestSinglePartition(true, true, true, "shared_inits");
+ TestSinglePartition(false, false, false, false);
+ TestSinglePartition(false, true, false, false);
+ TestSinglePartition(true, false, false, false);
+ TestSinglePartition(true, true, false, false);
+ TestSinglePartition(false, false, true, false);
+ TestSinglePartition(false, true, true, false);
+ TestSinglePartition(true, false, true, false);
+ TestSinglePartition(true, true, true, false);
+ TestSinglePartition(false, true, false, false, "shared_inits");
+ TestSinglePartition(true, true, false, false, "shared_inits");
+ TestSinglePartition(false, true, true, false, "shared_inits");
+ TestSinglePartition(true, true, true, false, "shared_inits");
+
+ TestSinglePartition(false, false, false, true);
+ TestSinglePartition(false, true, false, true);
+ TestSinglePartition(true, false, false, true);
+ TestSinglePartition(true, true, false, true);
+ TestSinglePartition(false, false, true, true);
+ TestSinglePartition(false, true, true, true);
+ TestSinglePartition(true, false, true, true);
+ TestSinglePartition(true, true, true, true);
+ TestSinglePartition(false, true, false, true, "shared_inits");
+ TestSinglePartition(true, true, false, true, "shared_inits");
+ TestSinglePartition(false, true, true, true, "shared_inits");
+ TestSinglePartition(true, true, true, true, "shared_inits");
}
TEST_F(SparsifyGatherTest, TestMultiPartition) {
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 598080ed27..e4fa6694d8 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -151,9 +151,10 @@ sh_binary(
"//tensorflow/contrib/ndlstm:ndlstm",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/predictor:predictor_pip",
- "//tensorflow/contrib/py2tf:py2tf_internal",
+ "//tensorflow/contrib/py2tf:py2tf",
"//tensorflow/contrib/py2tf/converters:converters",
"//tensorflow/contrib/py2tf/converters:test_lib",
+ "//tensorflow/contrib/py2tf/impl:impl",
"//tensorflow/contrib/py2tf/pyct:pyct",
"//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis",
"//tensorflow/contrib/receptive_field:receptive_field_pip",
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index 38a9007387..73d759eb13 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -65,7 +65,6 @@ BLACKLIST = [
"//tensorflow/contrib/framework:checkpoint_ops_testdata",
"//tensorflow/contrib/bayesflow:reinforce_simple_example",
"//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long
- "//tensorflow/contrib/py2tf:py2tf_internal",
"//tensorflow/contrib/timeseries/examples:predict",
"//tensorflow/contrib/timeseries/examples:multivariate",
"//tensorflow/contrib/timeseries/examples:known_anomaly",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index a456300fae..6c9b5e46ee 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -35,6 +35,7 @@ REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
'astor >= 0.6.0',
'gast >= 0.2.0',
+ 'grpcio >= 1.8.6',
'numpy >= 1.12.1',
'six >= 1.10.0',
'protobuf >= 3.4.0',
diff --git a/tensorflow/tools/test/file_name_test.py b/tensorflow/tools/test/file_name_test.py
new file mode 100644
index 0000000000..16fb8a822d
--- /dev/null
+++ b/tensorflow/tools/test/file_name_test.py
@@ -0,0 +1,48 @@
+#!/usr/bin/python
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Test that checks if we have any issues with case insensitive filesystems.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
+ERROR_MESSAGE = """
+Files with same name but different case detected in directory: {}
+"""
+
+
+def main():
+ # Make sure BASE_DIR ends with tensorflow. If it doesn't, we probably
+ # computed the wrong directory.
+ if os.path.split(BASE_DIR)[-1] != 'tensorflow':
+ raise AssertionError(
+ "BASE_DIR = '%s' doesn't end with tensorflow" % BASE_DIR)
+
+ for dirpath, dirnames, filenames in os.walk(BASE_DIR, followlinks=True):
+ lowercase_directories = [x.lower() for x in dirnames]
+ lowercase_files = [x.lower() for x in filenames]
+
+ lowercase_dir_contents = lowercase_directories + lowercase_files
+ if len(lowercase_dir_contents) != len(set(lowercase_dir_contents)):
+ raise AssertionError(ERROR_MESSAGE.format(dirpath))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tensorflow/tools/test/run_and_gather_logs_lib.py b/tensorflow/tools/test/run_and_gather_logs_lib.py
index a953ed1b53..3b4921bb98 100644
--- a/tensorflow/tools/test/run_and_gather_logs_lib.py
+++ b/tensorflow/tools/test/run_and_gather_logs_lib.py
@@ -136,7 +136,7 @@ def run_and_gather_logs(name, test_name, test_args,
gpu_config = gpu_info_lib.gather_gpu_devices()
if gpu_config:
gpu_name = gpu_config[0].model
- gpu_short_name_match = re.search(r"Tesla (K40|K80|P100)", gpu_name)
+ gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", gpu_name)
if gpu_short_name_match:
gpu_short_name = gpu_short_name_match.group(0)
test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_")
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 660520d36a..f965bd696f 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -352,11 +352,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "protobuf_archive",
urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz",
- "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz",
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
+ "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
],
- sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a",
- strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9",
+ sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3",
+ strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a",
)
# We need to import the protobuf library under the names com_google_protobuf
@@ -365,21 +365,21 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "com_google_protobuf",
urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz",
- "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz",
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
+ "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
],
- sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a",
- strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9",
+ sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3",
+ strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a",
)
tf_http_archive(
name = "com_google_protobuf_cc",
urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz",
- "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz",
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
+ "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
],
- sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a",
- strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9",
+ sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3",
+ strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a",
)
tf_http_archive(
@@ -472,11 +472,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/36a30fc7c9ee6fdfe5157190ad15c1801b1ab2de.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/36a30fc7c9ee6fdfe5157190ad15c1801b1ab2de.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/f135378ec6365e852f7d5a3cfcdce342f08cb5f3.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/f135378ec6365e852f7d5a3cfcdce342f08cb5f3.tar.gz",
],
- sha256 = "3b74ecd8f59c712b4daf715a4da15c43ebdd40edcd4c30737bffef62f6a2bc9d",
- strip_prefix = "llvm-36a30fc7c9ee6fdfe5157190ad15c1801b1ab2de",
+ sha256 = "296ab832167e6c46eb65ef1f9a2b5fc31c77fcd2248799b306aa2d5d2e4edbfe",
+ strip_prefix = "llvm-f135378ec6365e852f7d5a3cfcdce342f08cb5f3",
build_file = str(Label("//third_party/llvm:llvm.BUILD")),
)